import time
from typing import Dict, Any, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import gymnasium as gym
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from benchrl.algorithms.sac import SAC
from benchrl.utils._utils import DiagonalGaussian
from benchrl.utils._functions import linear_schedule

class PBSAC(SAC):
    """PAC-Bayes Soft Actor-Critic algorithm V2 - Enhanced version.
    
    Extends the base SAC algorithm with improved PAC-Bayesian guarantees for policy performance.
    This version includes enhanced importance sampling, better posterior optimization,
    and more robust PAC-Bayes bound computation.
    
    Key improvements over V1:
    - Importance-weighted empirical return evaluation
    - REINFORCE-style posterior updates
    - Better numerical stability
    - Enhanced rollout collection with log probabilities
    """
    
    def __init__(
        self,
        env,
        algo_config: Dict[str, Any],
        device: str = "auto",
        writer: Optional[SummaryWriter] = None,
    ):
        """Initialize PB-SAC V2 algorithm.

        Runs the base ``SAC`` initializer, then reads the PAC-Bayes
        hyper-parameters from ``algo_config`` (each with a default). When
        ``pac_bayes_active`` is true it also builds:

        - a diagonal-Gaussian posterior centred on the current actor parameters;
        - a fixed prior;
        - a learnable scalar ``pb_kappa``;
        - their optimizers.

        Finally it tries to build a separate vectorized environment for
        PAC-Bayes rollouts. If that fails, it falls back to the training
        environment.

        Args:
            env: Vectorized training environment. ``env.envs[0].spec.id`` is
                reused as the env id of the PAC-Bayes rollout environment.
            algo_config: Algorithm configuration including PAC-Bayes specific
                parameters. NOTE: entries of ``algo_config['pb_env_wrappers']``
                are modified in place (``_target_wrapper`` is renamed to
                ``_target_``).
            device: Device for computation (forwarded to ``SAC``).
            writer: Tensorboard writer (forwarded to ``SAC``).
        """
        # Initialize base SAC algorithm. Presumably this sets self.policy,
        # self.device, self.is_discrete and self.env used below (defined in SAC,
        # not visible here -- TODO confirm).
        super().__init__(env, algo_config, device, writer)

        # PAC-Bayes specific parameters - V2 defaults
        self.pac_bayes_active = algo_config.get('pac_bayes_active', True)
        self.num_pb_envs = algo_config.get('num_pb_envs', 10)
        self.pb_update_freq = algo_config.get('pb_update_freq', 5000)
        self.beta = algo_config.get('beta', 1e-3)  # learning rate of the posterior optimizer
        self.reg_weight = algo_config.get('reg_weight', 0.01)
        self.delta = algo_config.get('delta', 0.05)  # confidence level of the bound (ln(2/delta) term)
        self.r_max_estimate = algo_config.get('r_max_estimate', 1.0)
        self.use_thompson_sampling = algo_config.get('use_thompson_sampling', False)
        self.exploration_probability = algo_config.get('exploration_probability', 0.6)
        self.thompson_sampling_freq = algo_config.get('thompson_sampling_freq', 1000)
        self.pb_rollout_trajectories = algo_config.get('pb_rollout_trajectories', 1000)
        self.pb_rollout_steps = algo_config.get('pb_rollout_steps', 100)
        self.pb_policy_samples = algo_config.get('pb_policy_samples', 8)
        self.pb_update_epochs = algo_config.get('pb_update_epochs', 25)
        self.pb_reset_prior_freq = algo_config.get('pb_reset_prior_freq', 50)
        self.pb_prior_decay_duration = algo_config.get('pb_prior_decay_duration', 0.9)
        self.pb_prior_decay_end = algo_config.get('pb_prior_decay_end', 0.05)
        self.max_grad_norm = algo_config.get('max_grad_norm', 5.0)
        self.pb_initial_std = algo_config.get('pb_initial_std', 0.01)

        # Enhanced exploration parameters
        self.exploration_bonus_coeff = algo_config.get('exploration_bonus_coeff', 0.1)
        self.exploration_samples = algo_config.get('exploration_samples', 5)
        self.target_samples = algo_config.get('target_samples', 3)
        self.use_alternating_optimization = algo_config.get('use_alternating_optimization', True)

        # Actor freezing for stability
        self.actor_freeze_steps = algo_config.get('actor_freeze_steps', 2000)
        self.adaptation_exploration_samples = algo_config.get('adaptation_exploration_samples', 8)
        self.actor_frozen_until = 0  # Step when actor unfreezes

        # PAC-Bayes state variables
        self.max_ep_length = algo_config.get('max_ep_length', 0)
        self.episode_count = 0
        self.r_max = self.r_max_estimate  # running max |reward|; starts at the configured estimate
        self.mixing_time = algo_config.get('mixing_time', 100) # Initial guess
        self.last_pac_bayes_bound = 0.0
        self.replay_pb = []  # Storage for PAC-Bayes rollouts

        self.last_actor_loss = 0.0  # Store last actor loss for frozen periods

        # Initialize PAC-Bayes distributions if active
        if self.pac_bayes_active:
            # Extract policy parameters
            policy_params = nn.utils.parameters_to_vector(self.policy.actor_network.parameters())
            param_size = policy_params.numel()

            # Initialize posterior: mean = current actor weights, std = pb_initial_std everywhere
            self.posterior_mean = torch.nn.Parameter(policy_params.clone().detach().requires_grad_(True))
            posterior_log_std_tensor = torch.ones(param_size, device=self.device) * math.log(self.pb_initial_std)
            self.posterior_log_std = torch.nn.Parameter(posterior_log_std_tensor.requires_grad_(True))

            # Initialize prior (not optimized): same mean as the posterior, wider std
            # (10x pb_initial_std for discrete actions, a fixed 0.1 for continuous)
            self.prior_mean = self.posterior_mean.clone().detach()
            if self.is_discrete:
                prior_log_std_tensor = torch.ones(param_size, device=self.device) * math.log(self.pb_initial_std * 10)
            else:
                prior_log_std_tensor = torch.ones(param_size, device=self.device) * math.log(0.1)
            self.prior_log_std = torch.nn.Parameter(prior_log_std_tensor.requires_grad_(False))

            # Create parameter distributions
            self.posterior = DiagonalGaussian(self.posterior_mean, self.posterior_log_std)
            self.prior = DiagonalGaussian(self.prior_mean, self.prior_log_std)

            # Initialisation of kappa with a random positive value
            kappa_tensor = torch.empty(1).to(self.device).requires_grad_(True)
            # Initial kappa ~ U(0.1, 2.0), i.e. strictly positive
            torch.nn.init.uniform_(kappa_tensor, 0.1, 2.0)
            self.pb_kappa = torch.nn.Parameter(kappa_tensor)

            # Create optimizers for alternating optimization
            if self.use_alternating_optimization:
                # Posterior mean and log-std share one Adam optimizer (lr=beta);
                # kappa gets its own optimizer with a fixed lr of 1e-3.
                self.posterior_optimizer = optim.Adam(
                    [self.posterior_mean, self.posterior_log_std],
                    lr=self.beta
                )
                self.pb_kappa_optimizer = optim.Adam([self.pb_kappa], lr=1e-3)
            else:
                # Joint optimization: kappa shares the posterior optimizer.
                # NOTE(review): no `pb_kappa_optimizer` is created on this path --
                # confirm no caller requires it.
                self.posterior_optimizer = optim.Adam(
                    [self.posterior_mean, self.posterior_log_std, self.pb_kappa],
                    lr=self.beta
                )

        # Create separate environment for PAC-Bayes rollouts
        # (built even when pac_bayes_active is False)
        from benchrl.environments.registry import get_env_builder
        env_builder = get_env_builder()

        # Build evaluation environment with specified number of envs
        pb_env_wrappers = algo_config.get('pb_env_wrappers', None)
        if pb_env_wrappers is not None:
            # NOTE: renames '_target_wrapper' -> '_target_' in place, mutating the
            # caller's algo_config entries.
            for i in range(len(pb_env_wrappers)):
                if '_target_wrapper' in pb_env_wrappers[i]:
                    pb_env_wrappers[i]['_target_'] = pb_env_wrappers[i].pop('_target_wrapper')
        # Assumes `env` exposes `.envs` (e.g. a gymnasium SyncVectorEnv) -- TODO confirm.
        pb_env_config = {
            'env_id': env.envs[0].spec.id,
            'num_envs': self.num_pb_envs,
            'seed': 42,
            'capture_video': False,
            'video_folder': None,
            'wrappers': pb_env_wrappers,
        }

        try:
            self.pb_env = env_builder.build_env(**pb_env_config)
        except Exception as e:
            # Best-effort: fall back to the training env.
            # NOTE(review): PAC-Bayes rollouts then reset/step the training env itself.
            print(f"Warning: Could not create separate PB environment: {e}")
            self.pb_env = self.env  # Fallback to main environment
    
    def _flatten_policy_params(self):
        """Return the actor network's parameters concatenated into one flat vector."""
        actor_parameters = self.policy.actor_network.parameters()
        flat_vector = nn.utils.parameters_to_vector(actor_parameters)
        return flat_vector
        
    def _load_policy_params(self, params):
        """Write the flat parameter vector ``params`` back into the actor network.

        ``params`` must have as many elements as the actor has parameters, in
        the order produced by ``_flatten_policy_params``.

        NOTE(review): torch's ``vector_to_parameters`` assigns views of
        ``params`` to each ``param.data`` rather than copying, so the actor may
        share storage with ``params`` afterwards -- confirm callers expect this.
        """
        actor_parameters = self.policy.actor_network.parameters()
        nn.utils.vector_to_parameters(params, actor_parameters)
    
    def _frozen_policy_rollout(self):
        """
        Collect behaviour trajectories with the current (frozen) actor weights.

        ``pb_rollout_trajectories`` trajectories are split 90/10:

        - a training set, whose trajectories are capped at ``pb_rollout_steps``
          steps each;
        - a test set of full episodes, ended only by environment termination or
          truncation.

        Actions are sampled stochastically. Every step also records the
        behaviour log-probability of the taken action, which the importance
        sampling in ``evaluate_policy_on_trajectories`` needs later.

        Side effects:
            Updates ``self.r_max`` with the largest absolute reward observed.
            Stores ``train + test`` in ``self.replay_pb``.

        Returns:
            Tuple ``(replay_pb_update, replay_pb_test)`` of trajectory lists. Each
            trajectory is a list of dicts with 'state', 'action', 'reward' and
            'log_prob' keys.
        """
        total_trajectories = self.pb_rollout_trajectories

        # Calculate splits: 90% for training, 10% for testing
        train_trajectories = int(total_trajectories * 0.9)
        test_trajectories = total_trajectories - train_trajectories

        print("Collecting PAC-Bayes training rollouts...")
        replay_pb_update = self._collect_pb_trajectories(
            train_trajectories, max_steps=self.pb_rollout_steps
        )

        print("Collecting PAC-Bayes test rollouts (full trajectories)...")
        replay_pb_test = self._collect_pb_trajectories(test_trajectories)

        # Store for backward compatibility
        self.replay_pb = replay_pb_update + replay_pb_test

        return replay_pb_update, replay_pb_test

    def _collect_pb_trajectories(self, num_trajectories, max_steps=None):
        """Roll out ``self.pb_env`` in batches until ``num_trajectories`` are collected.

        Each batch resets all sub-environments and steps them until every one has
        finished its trajectory. A trajectory finishes when the env reports done
        or truncated, or once it holds ``max_steps`` steps (if ``max_steps`` is
        not None). Steps from sub-environments that already finished are ignored.

        Args:
            num_trajectories: number of trajectories to return.
            max_steps: optional per-trajectory step cap; None means full episodes.

        Returns:
            List of ``num_trajectories`` trajectories (lists of step dicts).
        """
        num_envs = self.pb_env.num_envs
        collected = []
        num_batches = (num_trajectories + num_envs - 1) // num_envs
        for _ in tqdm(range(num_batches)):
            states, _ = self.pb_env.reset()
            env_trajectories = [[] for _ in range(num_envs)]
            active_envs = [True] * num_envs

            while any(active_envs):
                with torch.no_grad():
                    # Stochastic actions: the stored behaviour log-probs must come
                    # from a sampling policy for importance sampling to be valid.
                    obs_tensor = torch.tensor(states, dtype=torch.float32, device=self.device)
                    actions, log_probs, _ = self.policy.get_action(obs_tensor, deterministic=False)
                    actions = actions.detach().cpu().numpy()
                    log_probs = log_probs.detach().cpu().numpy()

                next_states, rewards, dones, truncateds, _ = self.pb_env.step(actions)

                self.r_max = max(self.r_max, np.max(np.abs(rewards)))

                for i in range(num_envs):
                    if not active_envs[i]:
                        continue

                    env_trajectories[i].append({
                        'state': states[i],
                        'action': actions[i],
                        'reward': rewards[i],
                        'log_prob': log_probs[i],  # behaviour log_prob for importance sampling
                    })

                    # Check the cap AFTER appending so a capped trajectory has
                    # exactly max_steps steps (not max_steps + 1).
                    reached_cap = max_steps is not None and len(env_trajectories[i]) >= max_steps
                    if dones[i] or truncateds[i] or reached_cap:
                        active_envs[i] = False

                states = next_states

            remaining = num_trajectories - len(collected)
            collected.extend(env_trajectories[:min(num_envs, remaining)])
        return collected

    def estimate_mixing_time(self, replay_pb):
        """Estimate a mixing time from reward autocorrelations.

        Each trajectory is processed as follows:

        - Trajectories with fewer than 10 steps or (near-)constant rewards are
          skipped.
        - The mean-centred reward sequence is autocorrelated and normalised by
          its lag-0 value.
        - The first lag where the normalised value drops below 0.2 is taken (or
          the last lag if it never does).
        - That lag is scaled by 1.5 and truncated to int; lags <= 1 map to 1.

        The result is the maximum over trajectories, never below 1.

        Args:
            replay_pb: Iterable of trajectories, each a list of step dicts with
                a ``'reward'`` entry.

        Returns:
            int: the estimated mixing time (>= 1).
        """
        estimate = 1
        for trajectory in replay_pb:
            signal = np.array([step['reward'] for step in trajectory])
            length = len(signal)
            if length < 10:
                continue

            signal = signal - signal.mean()
            # Constant rewards carry no correlation information.
            if np.var(signal) < 1e-8:
                continue

            autocorr = np.correlate(signal, signal, mode='full')
            zero_lag = autocorr[length - 1]
            if abs(zero_lag) < 1e-8:
                continue
            normalised = autocorr[length - 1:] / zero_lag

            crossing = len(normalised) - 1
            for lag, value in enumerate(normalised):
                if value < 0.2:
                    crossing = lag
                    break

            candidate = int(crossing * 1.5) if crossing > 1 else 1
            estimate = max(estimate, candidate)

        return estimate

    def evaluate_policy_on_trajectories(self, policy_params, replay_pb, precomputed_returns):
        """
        Estimate the return of ``policy_params`` on behaviour trajectories.

        The actor is temporarily loaded with ``policy_params``. For each
        trajectory it computes the log importance ratio
        ``log_rho = sum_t [log pi(a_t|s_t) - log b(a_t|s_t)]``, where ``log b``
        is the behaviour log-prob stored at collection time. A softmax turns the
        ratios into self-normalised weights, and the weighted mean of
        ``precomputed_returns`` is returned.

        When the largest |log_rho| is >= 1, the log-ratios are first divided by
        its power of ten (``10**floor(log10(max|log_rho|))``) to temper the
        weights. Smaller ratios are used as-is. Dividing by a factor < 1 would
        blow tiny, even round-off-level, differences up into near one-hot
        weights.

        Args:
            policy_params: Flat actor parameter vector (layout of
                ``_flatten_policy_params``).
            replay_pb: List of trajectories (lists of step dicts with 'state',
                'action' and 'log_prob' keys).
            precomputed_returns: 1-D tensor of per-trajectory returns aligned
                with ``replay_pb``.

        Returns:
            0-dim tensor with the weighted return estimate. The actor's original
            parameter values are restored before returning, also on error.
        """
        current_params = self._flatten_policy_params().clone()
        self._load_policy_params(policy_params)
        action_dtype = torch.long if self.is_discrete else torch.float32

        try:
            log_rhos = []
            for traj in replay_pb:
                states = torch.as_tensor(
                    np.array([s['state'] for s in traj]), dtype=torch.float32, device=self.device
                )
                actions = torch.as_tensor(
                    np.array([s['action'] for s in traj]), dtype=action_dtype, device=self.device
                )
                with torch.no_grad():
                    # Flatten to 1-D: a [T, 1] log-prob minus a [T] one would
                    # otherwise broadcast to [T, T] and corrupt the sum.
                    log_prob_pi = self.policy.get_log_prob(states, actions).reshape(-1)
                log_prob_b = torch.as_tensor(
                    np.array([s['log_prob'] for s in traj]), dtype=torch.float32, device=self.device
                ).reshape(-1)
                # Sum log probs to get log of trajectory probability ratio
                log_rhos.append(torch.sum(log_prob_pi - log_prob_b))
        finally:
            self._load_policy_params(current_params)  # Restore original actor

        log_rhos_tensor = torch.stack(log_rhos)

        # Temper only when the log-ratios are large; softmax itself is already
        # numerically stable (it subtracts the max internally).
        max_log_rho = torch.abs(log_rhos_tensor).max()
        magnitude = 10 ** torch.floor(torch.log10(torch.clamp(max_log_rho, min=1e-8)))
        temperature = torch.clamp(magnitude, min=1.0)
        weights = torch.softmax(log_rhos_tensor / temperature, dim=0)

        # Weighted average of returns (self-normalised; works for negative returns).
        return torch.sum(weights * precomputed_returns)

    def update_pac_bayes_components(self, replay_pb):
        """
        Update the PAC-Bayes posterior (and kappa) on a batch of behaviour trajectories.

        Phase 1 runs ``pb_update_epochs`` REINFORCE-style steps on the posterior
        mean and log-std. Each step:

        1. Samples ``pb_policy_samples`` parameter vectors ``theta_i`` from the
           posterior ``q``.
        2. Scores each ``theta_i`` by its importance-weighted discounted return
           on a random 80% subset of ``replay_pb``.
        3. Minimises::

               -mean_i(A_i * log q(theta_i))
                 + reg_weight * ((KL(q||p) + ln(2/delta)) / kappa + kappa * c^2 * tau / 8)

           where ``A_i`` is the mean-centred return of sample ``i``.

        How kappa is optimised depends on the mode:

        - Alternating mode: kappa is held fixed during Phase 1. Phase 2 then
          takes one step on kappa for the same bound terms, with the updated
          posterior held fixed.
        - Joint mode: kappa has no separate optimizer. It is updated together
          with the posterior in Phase 1.

        Args:
            replay_pb: List of trajectories, each a list of step dicts with
                'state', 'action', 'reward' and 'log_prob' keys (as produced by
                ``_frozen_policy_rollout``).

        Returns:
            Dict with 'pb_loss', 'kl_div', 'mean_empirical_return',
            'posterior_std' and 'kappa'. Empty dict when PAC-Bayes is inactive.
        """
        if not self.pac_bayes_active: return {}

        print("Optimizing PAC-Bayes posterior with alternating optimization...")

        alternating = self.use_alternating_optimization
        log_two_over_delta = math.log(2 / self.delta)

        H = self.max_ep_length
        T = max(self.episode_count, 1)
        # adjust T to account for additional trajectories in replay buffer
        T = T + (len(replay_pb) * (self.global_step // self.pb_update_freq))
        c_squared = max(1e-6, self.r_max ** 2 * (1 - self.gamma ** (2 * H)) / (T * (1 - self.gamma ** 2)))

        # Discounted return of every trajectory, computed ONCE and indexed per epoch.
        all_returns = torch.tensor([
            sum(s['reward'] * (self.gamma ** i) for i, s in enumerate(traj))
            for traj in replay_pb
        ], dtype=torch.float32, device=self.device)
        num_sampled_trajectories = max(1, int(0.8 * len(replay_pb)))

        # Defaults keep the reporting below well-defined if pb_update_epochs == 0.
        policy_loss = torch.zeros((), device=self.device)
        estimated_returns = torch.zeros(1, device=self.device)
        policy_samples, log_probs = [], None

        # --- Phase 1: Optimize Posterior (kappa fixed unless in joint mode) ---
        for epoch in range(self.pb_update_epochs):
            kappa = self.pb_kappa.detach() if alternating else self.pb_kappa

            # Score-function (REINFORCE) gradients need samples treated as constants.
            policy_samples = [self.posterior.sample().detach() for _ in range(self.pb_policy_samples)]
            # sample 0.8 of trajectories from replay_pb
            sampled_indices = np.random.choice(len(replay_pb), num_sampled_trajectories, replace=False)
            sampled_trajectories = [replay_pb[i] for i in sampled_indices]
            discounted_returns = all_returns[
                torch.as_tensor(sampled_indices, dtype=torch.long, device=self.device)
            ]

            # Evaluate policies using importance sampling
            estimated_returns = torch.stack([
                self.evaluate_policy_on_trajectories(params, sampled_trajectories, discounted_returns)
                for params in policy_samples
            ]).detach()
            advantages = estimated_returns - estimated_returns.mean()

            self.posterior_optimizer.zero_grad()

            # One joint log-density per sample, shape [S]. Multiplying by
            # `log_probs.unsqueeze(1)` instead would broadcast to [S, S], whose
            # mean is mean(A) * mean(log q) == 0 -- i.e. no policy gradient at all.
            log_probs = torch.stack(
                [self.posterior.log_prob(theta) for theta in policy_samples]
            ).reshape(len(policy_samples), -1).sum(dim=1)
            policy_loss = -(advantages * log_probs).mean()
            kl_div = self.posterior.kl_divergence(self.prior)

            # Loss = policy_loss + reg_weight * ((KL(ρ||μ) + ln(2/δ))/κ + κ||c||²τ_min/8)
            kl_term = (kl_div + log_two_over_delta) / kappa
            regularization_term = (kappa * c_squared * self.mixing_time) / 8.0
            posterior_loss = policy_loss + self.reg_weight * (kl_term + regularization_term)

            if self.is_discrete:
                posterior_loss.backward(retain_graph=True)
            else:
                posterior_loss.backward(retain_graph=(epoch < self.pb_update_epochs - 1))
            torch.nn.utils.clip_grad_norm_([self.posterior_mean, self.posterior_log_std], max_norm=self.max_grad_norm)
            self.posterior_optimizer.step()
            if not alternating:
                # Keep kappa strictly positive between joint steps.
                with torch.no_grad():
                    self.pb_kappa.data.clamp_(min=1e-6)
            # Rebuild so anything DiagonalGaussian derives from the parameters at
            # construction reflects this step (same as the rebuild after clamping below).
            self.posterior = DiagonalGaussian(self.posterior_mean, self.posterior_log_std)

            if epoch % max(self.pb_update_epochs // 5, 1) == 0:
                print(f"  Posterior Epoch {epoch}: Loss={posterior_loss.item():.4f}, KL={kl_div.item():.4f}, Return={estimated_returns.mean().item():.4f}")

        # Clean up gradients from the last iteration of the loop
        self.posterior_optimizer.zero_grad()

        # --- Phase 2: Optimize Kappa given fixed Posterior ---
        with torch.no_grad():
            final_kl_div = self.posterior.kl_divergence(self.prior)
            # Constant wrt kappa
            final_policy_loss = policy_loss.detach()

        # Loss = final_policy_loss + (KL(ρ||μ) + ln(2/δ))/κ + κ||c||²τ_min/8
        with torch.set_grad_enabled(alternating):
            kappa_loss = (
                final_policy_loss
                + (final_kl_div + log_two_over_delta) / self.pb_kappa
                + (self.pb_kappa * c_squared * self.mixing_time) / 8.0
            )
        if alternating:
            self.pb_kappa_optimizer.zero_grad()
            kappa_loss.backward()
            self.pb_kappa_optimizer.step()
        # Joint mode: no separate kappa optimizer exists; kappa_loss is only reported.

        mean_return = estimated_returns.mean().item()
        print(f"  Kappa optimization: Loss={kappa_loss.item():.4f}, Kappa={self.pb_kappa.item():.4f}, Return={mean_return:.4f}")

        final_pb_loss = kappa_loss.item()
        final_kl = final_kl_div.item()

        # Clamp posterior log std to prevent it from collapsing or exploding,
        # and keep kappa strictly positive (no upper bound).
        with torch.no_grad():
            self.posterior_log_std.data.clamp_(min=math.log(1e-4), max=math.log(0.5))
            self.pb_kappa.data.clamp_(min=1e-6)

        self.posterior = DiagonalGaussian(self.posterior_mean, self.posterior_log_std)

        # Delete samples and clear cache to free GPU memory (after all gradients computed)
        del policy_samples, log_probs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return {
            'pb_loss': final_pb_loss,
            'kl_div': final_kl,
            'mean_empirical_return': mean_return,
            'posterior_std': torch.exp(self.posterior_log_std).mean().item(),
            'kappa': self.pb_kappa.item()
        }
        
    def compute_pac_bayes_bound(self, replay_pb):
        """
        Compute the PAC-Bayes certified lower bound on the posterior's return.

        ``certified_return = empirical - sqrt(0.5 * c^2 * tau * (KL(q||p) + ln(2/delta)))``

        The terms are:

        - ``empirical``: the mean importance-weighted discounted return of
          ``pb_policy_samples`` fresh posterior samples on ``replay_pb``.
        - ``c^2 = r_max^2 (1 - gamma^(2H)) / (T (1 - gamma^2))``, floored at
          1e-6.
        - ``H = max_ep_length``, ``T = max(episode_count, 1)`` and
          ``tau = mixing_time``.

        Args:
            replay_pb: List of trajectories (lists of step dicts with 'state',
                'action', 'reward' and 'log_prob' keys).

        Returns:
            Dict with 'certified_return', 'empirical_performance',
            'uncertainty_term', 'mixing_time', 'kl_div', 'c_squared' and 'r_max'.
            When PAC-Bayes is inactive, only the first three keys are returned,
            all 0.0, and no work is done.
        """
        if not self.pac_bayes_active:
            return {'certified_return': 0.0, 'empirical_performance': 0.0, 'uncertainty_term': 0.0}

        # Pre-calculate discounted returns for each trajectory ONCE
        discounted_returns = torch.tensor([
            sum(s['reward'] * (self.gamma ** i) for i, s in enumerate(traj))
            for traj in replay_pb
        ], dtype=torch.float32, device=self.device)

        with torch.no_grad():
            # 1. Empirical performance E_{theta~rho}[-L_D(theta)], from a fresh batch
            #    of samples of the *updated* posterior.
            policy_samples = [self.posterior.sample() for _ in range(self.pb_policy_samples)]
            empirical_returns = torch.stack([
                self.evaluate_policy_on_trajectories(params, replay_pb, discounted_returns)
                for params in policy_samples
            ])
            empirical_performance = empirical_returns.mean().item()
            # A scalar for the bound; no autograd graph needed.
            kl_div = self.posterior.kl_divergence(self.prior).item()

            # Delete samples to free GPU memory
            del policy_samples, empirical_returns
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # 2. Uncertainty term (the square root part)
        H = self.max_ep_length
        # NOTE(review): update_pac_bayes_components additionally inflates T by the
        # number of PB trajectories collected so far; this bound uses the raw
        # episode count -- confirm which is intended.
        T = max(self.episode_count, 1)
        c_squared_numerator = self.r_max ** 2 * (1 - self.gamma ** (2 * H))
        c_squared_denominator = T * (1 - self.gamma ** 2)
        c_squared = max(1e-6, c_squared_numerator / c_squared_denominator)

        term_inside_sqrt = 0.5 * c_squared * self.mixing_time * (kl_div + math.log(2 / self.delta))
        uncertainty_term = math.sqrt(term_inside_sqrt)

        # 3. Final certified lower bound
        certified_return = empirical_performance - uncertainty_term

        bound_info = {
            'certified_return': certified_return,
            'empirical_performance': empirical_performance,
            'uncertainty_term': uncertainty_term,
            'mixing_time': self.mixing_time,
            'kl_div': kl_div,
            'c_squared': c_squared,
            'r_max': self.r_max
        }

        print(f"PAC-Bayes Certified Return: {bound_info['certified_return']:.4f} "
              f"(Empirical: {bound_info['empirical_performance']:.4f}, Uncertainty: {bound_info['uncertainty_term']:.4f})")

        return bound_info

    def reset_prior(self, decay=0.99):
        """Move the prior towards the current posterior.

        The mixing weight ``self.decay`` comes from ``linear_schedule``, which is
        called with ``decay`` as the start value, ``pb_prior_decay_end`` as the
        end value, a duration of ``pb_prior_decay_duration * total_timesteps``
        and the current ``global_step``.

        The prior is then updated as follows:

        - The prior mean becomes ``(1 - w) * prior_mean + w * posterior_mean``,
          where ``w = self.decay``.
        - The prior log-std is overwritten with the posterior log-std.

        Finally the prior distribution object is rebuilt. Does nothing when
        PAC-Bayes is inactive.
        """
        if self.pac_bayes_active:
            schedule_horizon = self.pb_prior_decay_duration * self.total_timesteps
            self.decay = linear_schedule(
                decay,
                self.pb_prior_decay_end,
                schedule_horizon,
                self.global_step,
            )
            weight = self.decay

            with torch.no_grad():
                target_mean = self.posterior_mean.detach()
                self.prior_mean = weight * target_mean + (1 - weight) * self.prior_mean
                self.prior_log_std.copy_(self.posterior_log_std.detach().data)

            self.prior = DiagonalGaussian(self.prior_mean, self.prior_log_std)
        
    def inject_posterior_knowledge(self):
        """Load the posterior mean into the actor and freeze the actor for a while.

        The actor's parameters are replaced by the posterior mean directly (no
        mixing with the old weights). ``actor_frozen_until`` is then set to
        ``global_step + actor_freeze_steps``, giving the critics time to adapt.
        """
        target_params = self.posterior_mean.detach()
        with torch.no_grad():
            self._load_policy_params(target_params)
            # Critic adaptation window
            unfreeze_step = self.global_step + self.actor_freeze_steps
            self.actor_frozen_until = unfreeze_step
            print(f"Actor frozen until step {self.actor_frozen_until} for critic adaptation")
    
    def get_exploration_bonus(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Compute an exploration bonus from posterior disagreement in Q-values.

        Draws ``exploration_samples`` actor parameter vectors from the posterior.
        For each one:

        1. The actor picks a deterministic action for the states in ``obs``.
        2. The action is scored with ``min(Q1, Q2)`` from the current critics.

        The bonus is ``exploration_bonus_coeff`` times the per-element standard
        deviation of those scores across samples. The actor's original parameter
        values are restored afterwards, also on error.

        Args:
            obs: Batch of observations (leading dimension is the batch).
            action: Unused; kept for interface compatibility.

        Returns:
            One bonus per flattened Q-value (presumably shape ``[batch_size]``).
            Fewer than two samples make the standard deviation undefined (NaN for
            one sample, a crash for zero), so zeros are returned in that case.
        """
        with torch.no_grad():
            if self.exploration_samples < 2:
                return torch.zeros(obs.shape[0], device=obs.device)

            current_params = self._flatten_policy_params().clone()
            q_values = []
            try:
                for _ in range(self.exploration_samples):
                    self._load_policy_params(self.posterior.sample())

                    # Generate NEW actions with this posterior sample
                    sampled_actions, _, _ = self.policy.get_action(obs, deterministic=True)
                    if self.is_discrete:
                        index = sampled_actions.long()
                        q1 = self.policy.critic_network(obs).gather(1, index).flatten()
                        q2 = self.policy.critic2_network(obs).gather(1, index).flatten()
                    else:
                        q1 = self.policy.critic_network(obs, sampled_actions).flatten()
                        q2 = self.policy.critic2_network(obs, sampled_actions).flatten()
                    q_values.append(torch.min(q1, q2))
            finally:
                self._load_policy_params(current_params)

            # Disagreement (std over samples) as exploration bonus. No
            # torch.cuda.empty_cache() here: it does not make memory available to
            # PyTorch and is costly if called on every step.
            q_std = torch.stack(q_values).std(dim=0)  # [batch_size]
            return self.exploration_bonus_coeff * q_std

    def compute_posterior_guided_targets(self, next_obs: torch.Tensor, rewards: torch.Tensor,
                                           dones: torch.Tensor, actor_is_frozen: bool) -> torch.Tensor:
        """Compute soft Q-learning targets averaged over posterior actor samples.

        Sample 0 uses the posterior mean. The remaining samples are full
        posterior draws. The number of samples is ``adaptation_exploration_samples``
        while the actor is frozen, and ``target_samples`` otherwise.

        For each sample, the soft value of ``next_obs`` is computed from the
        target critics:

        - Continuous: ``min(Q1', Q2')(s', a') - alpha * log pi(a'|s')``.
        - Discrete: the expectation over ``next_action_probs``.

        The values are averaged over samples, and the result is
        ``r + (1 - done) * gamma * mean``.

        The actor is always left on the posterior mean afterwards, also on error.
        NOTE(review): because ``vector_to_parameters`` assigns views of the
        given vector, the actor then shares storage with ``self.posterior_mean``.
        Confirm that this coupling is intended.

        Args:
            next_obs: Batch of next observations.
            rewards: Batch of rewards. Flattened, like ``dones``, so a
                ``[B, 1]`` input cannot broadcast against the ``[B]`` targets
                into ``[B, B]``.
            dones: Batch of done flags.
            actor_is_frozen: Selects the number of posterior samples.

        Returns:
            Target tensor, presumably of shape ``[batch_size]``.
        """
        with torch.no_grad():
            posterior_mean = self.posterior_mean.detach()
            num_samples = self.adaptation_exploration_samples if actor_is_frozen else self.target_samples

            target_values = []
            try:
                for i in range(num_samples):
                    if i == 0:
                        # First sample: use posterior mean exactly (most stable)
                        self._load_policy_params(posterior_mean)
                    else:
                        self._load_policy_params(self.posterior.sample().detach())

                    if self.is_discrete:
                        _, next_log_probs, next_action_probs = self.policy.get_action(next_obs)
                        min_qf_next_target = torch.min(self.target_critic1(next_obs), self.target_critic2(next_obs))
                        target_q = (next_action_probs * (min_qf_next_target - self.alpha * next_log_probs.unsqueeze(1))).sum(dim=1)
                    else:
                        next_actions, next_log_probs, _ = self.policy.get_action(next_obs)
                        target_q1 = self.target_critic1(next_obs, next_actions).flatten()
                        target_q2 = self.target_critic2(next_obs, next_actions).flatten()
                        target_q = torch.min(target_q1, target_q2) - self.alpha * next_log_probs.flatten()

                    target_values.append(target_q)
            finally:
                # Leave the actor on the posterior mean, as before.
                self._load_policy_params(posterior_mean)

            # Plain mean over the posterior samples (no weighting). No
            # torch.cuda.empty_cache() here: it does not make memory available to
            # PyTorch and is costly if called on every training step.
            ensemble_target = torch.stack(target_values).mean(dim=0)  # [batch_size]
            return rewards.flatten() + (1 - dones.flatten()) * self.gamma * ensemble_target
    
    def get_ucb_action_selection(self, obs: torch.Tensor) -> torch.Tensor:
        """Pick, per environment, the posterior-sampled action with the highest min-Q.

        ``exploration_samples`` candidate actions are drawn stochastically:
        candidate 0 from the actor loaded with the posterior mean, the others
        from full posterior samples. Each candidate is scored with
        ``min(Q1, Q2)``, and the best-scoring candidate is returned for every
        environment.

        Despite the name, no explicit confidence bonus or PAC-Bayes bound term is
        added to the score. The actor's original parameter values are restored
        afterwards, also on error. Previously the actor was left holding the
        last random posterior sample.

        Args:
            obs: Batch of observations, one per environment.

        Returns:
            Tensor of selected actions, one per environment.
        """
        with torch.no_grad():
            current_params = self._flatten_policy_params().clone()
            candidate_actions = []
            candidate_scores = []

            if self.is_discrete:
                # Critic outputs do not depend on the actor, so compute them once.
                qf1_all, qf2_all = self.policy.get_q_values(obs)
                qf1_all = qf1_all.unsqueeze(0)
                qf2_all = qf2_all.unsqueeze(0)

            try:
                for i in range(self.exploration_samples):
                    theta = self.posterior_mean if i == 0 else self.posterior.sample()
                    self._load_policy_params(theta)

                    action, _, _ = self.policy.get_action(obs, deterministic=False)
                    if self.is_discrete:
                        index = action.unsqueeze(0).long()
                        qf1 = qf1_all.gather(1, index).flatten()
                        qf2 = qf2_all.gather(1, index).flatten()
                    else:
                        qf1, qf2 = self.policy.get_q_values(obs, action)

                    candidate_actions.append(action)
                    candidate_scores.append(torch.min(qf1, qf2))
            finally:
                self._load_policy_params(current_params)

            best_idx = torch.stack(candidate_scores).argmax(dim=0)

            # Gather the best candidate for each env: [num_samples, num_envs, ...] -> [num_envs, ...]
            candidate_actions_tensor = torch.stack(candidate_actions)
            env_idx = torch.arange(best_idx.shape[0], device=best_idx.device)
            return candidate_actions_tensor[best_idx, env_idx]
      
    def collect_rollouts(self, num_steps: int = 1) -> Dict[str, Any]:
        """Step the vectorized environment ``num_steps`` times and store transitions.

        Action selection at each step:
          * before ``learning_starts``: uniform random actions;
          * if ``use_thompson_sampling`` and ``global_step`` is a multiple of
            ``thompson_sampling_freq``: actions from a policy whose parameters
            are sampled from ``self.posterior``;
          * otherwise, with probability ``exploration_probability``: selection
            via ``get_ucb_action_selection``;
          * otherwise: the current stochastic policy.
        Temporarily swapped actor parameters are always restored, even if
        action selection raises. ``exploration_probability`` is annealed
        linearly from 0.8 to 0.1 over the first 80% of ``total_timesteps``.

        Args:
            num_steps: Number of vectorized environment steps to take.

        Returns:
            Empty dict if no episode finished. Otherwise ``rollout/episodic_return``
            and ``rollout/episodic_length`` (lists) and ``rollout/episodes``
            (count accumulated over all steps of this call).
        """
        if self.last_obs is None:
            self.last_obs, _ = self.env.reset()

        episode_returns = []
        episode_lengths = []
        rollout_episodes = 0  # accumulated across every step of this call
        num_envs = self.env.num_envs

        # Pre-allocate the observation tensor once to avoid repeated allocation
        if not hasattr(self, '_obs_tensor'):
            self._obs_tensor = torch.zeros((num_envs, *self.env.single_observation_space.shape),
                                           dtype=torch.float32, device=self.device)

        for _ in range(num_steps):
            self.global_step += num_envs
            self._obs_tensor.copy_(torch.from_numpy(self.last_obs))

            with torch.no_grad():
                if self.global_step < self.learning_starts:
                    # Vectorized random action sampling
                    if self.is_discrete:
                        actions = np.random.randint(0, self.action_space.n, size=(num_envs,))
                    else:
                        actions = np.random.uniform(
                            self.action_space.low, self.action_space.high,
                            size=(num_envs, self.action_space.shape[0])
                        )
                else:
                    if (self.use_thompson_sampling and
                            self.global_step % self.thompson_sampling_freq == 0):
                        # Thompson sampling: act with a posterior-sampled actor
                        current_params = self._flatten_policy_params().clone()
                        try:
                            self._load_policy_params(self.posterior.sample().detach())
                            actions, _, _ = self.policy.get_action(self._obs_tensor, deterministic=False)
                        finally:
                            self._load_policy_params(current_params)
                        del current_params
                    elif np.random.random() < self.exploration_probability:
                        # UCB-style selection swaps actor parameters internally
                        current_params = self._flatten_policy_params().clone()
                        try:
                            actions = self.get_ucb_action_selection(self._obs_tensor)
                        finally:
                            self._load_policy_params(current_params)
                        del current_params
                    else:
                        # Regular action selection
                        actions, _, _ = self.policy.get_action(self._obs_tensor, deterministic=False)

                    self.exploration_probability = linear_schedule(
                        0.8, 0.1,
                        0.8 * self.total_timesteps,
                        self.global_step
                    )
                    actions = actions.cpu().numpy()

            # Take environment step
            next_obs, rewards, terminations, truncations, infos = self.env.step(actions)

            self.r_max = max(self.r_max, np.max(np.abs(rewards)))

            # Record rewards for plotting purposes
            if "final_info" in infos:
                eps_final_info = infos["final_info"]['episode']
                finished = len(eps_final_info["r"])
                rollout_episodes += finished
                self.episode_count += finished
                # Only the first entry is recorded per step.
                # NOTE(review): with gymnasium>=1.0 vector envs, "r" may hold one
                # entry per env plus a "_r" mask; confirm only finished envs count.
                if finished:
                    episode_returns.append(eps_final_info["r"][0])
                    episode_lengths.append(eps_final_info["l"][0])
            self.max_ep_length = max(self.max_ep_length, max(episode_lengths) if episode_lengths else 0)

            # Truncated envs were auto-reset; bootstrap from their true final obs
            real_next_obs = next_obs.copy()
            for idx, trunc in enumerate(truncations):
                if trunc:
                    real_next_obs[idx] = infos["final_obs"][idx]

            self.replay_buffer.add(
                obs=self.last_obs,
                next_obs=real_next_obs,
                action=actions,
                reward=rewards,
                done=terminations,
                infos=infos
            )

            # Crucial: advance the observation for the next step
            self.last_obs = next_obs

        metrics = {}
        if episode_returns:
            metrics.update({
                'rollout/episodic_return': episode_returns,
                'rollout/episodic_length': episode_lengths,
                'rollout/episodes': rollout_episodes
            })

        return metrics
    
    def train_step(self) -> Dict[str, Any]:
        """Run one SAC training iteration plus periodic PAC-Bayes maintenance.

        Steps:
          1. Collect ``self.train_freq`` environment steps.
          2. Once ``learning_starts`` is reached, update both critics against
             targets from ``compute_posterior_guided_targets``.
          3. Every ``policy_frequency`` steps (unless the actor is frozen until
             ``actor_frozen_until``), run ``policy_frequency`` actor updates and,
             when ``ent_coeff == 'auto'``, entropy-coefficient updates.
          4. Update target networks every ``target_update_interval`` steps.
          5. If PAC-Bayes is active: every ``pb_update_freq`` steps collect
             fresh rollouts, refit the posterior, compute the bound and inject
             posterior knowledge; on every step re-centre the posterior on the
             current actor; every ``pb_reset_prior_freq`` steps (and at
             ``learning_starts``) reset the prior.

        Returns:
            Rollout metrics merged with training and PAC-Bayes metrics.
        """
        rollout_metrics = self.collect_rollouts(self.train_freq)

        # Skip training if not enough samples
        if self.global_step < self.learning_starts:
            return rollout_metrics

        actor_is_frozen = self.global_step < self.actor_frozen_until

        data = self.replay_buffer.sample(self.batch_size)

        # Posterior-guided critic targets use the original rewards (no
        # exploration bonus), which avoids a distribution shift as a bonus decays.
        next_q_value = self.compute_posterior_guided_targets(
            data.next_observations, data.rewards.flatten(), data.dones, actor_is_frozen
        )

        # Current Q-values
        if self.is_discrete:
            qf1_a_values = self.policy.critic_network(data.observations).gather(1, data.actions.long()).flatten()
            qf2_a_values = self.policy.critic2_network(data.observations).gather(1, data.actions.long()).flatten()
        else:
            qf1_a_values = self.policy.critic_network(data.observations, data.actions).flatten()
            qf2_a_values = self.policy.critic2_network(data.observations, data.actions).flatten()

        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        self.critic_optimizer.zero_grad()
        qf_loss.backward()
        self.critic_optimizer.step()

        alpha_loss = None  # set only when an entropy-coefficient update runs
        if not actor_is_frozen and self.global_step % self.policy_frequency == 0:
            for _ in range(self.policy_frequency):
                if self.is_discrete:
                    _, log_pi, action_probs = self.policy.get_action(data.observations)
                    with torch.no_grad():
                        qf1_values = self.policy.critic_network(data.observations)
                        qf2_values = self.policy.critic2_network(data.observations)
                        min_qf_values = torch.min(qf1_values, qf2_values)
                    actor_loss = (action_probs * ((self.alpha * log_pi.unsqueeze(1)) - min_qf_values)).mean()
                    if self.ent_coeff == 'auto':
                        alpha_loss = (action_probs.detach() * (-self.log_ent_coeff.exp() * (log_pi.unsqueeze(1).detach() + self.target_entropy))).mean()
                else:
                    pi, log_pi, _ = self.policy.get_action(data.observations)
                    qf1_pi = self.policy.critic_network(data.observations, pi).flatten()
                    qf2_pi = self.policy.critic2_network(data.observations, pi).flatten()
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    actor_loss = (self.alpha * log_pi.flatten() - min_qf_pi).mean()
                    if self.ent_coeff == 'auto':
                        alpha_loss = (-self.log_ent_coeff.exp() * (log_pi.detach().flatten() + self.target_entropy)).mean()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()
                self.last_actor_loss = actor_loss.item()

                if self.ent_coeff == 'auto':
                    self.ent_coeff_optimizer.zero_grad()
                    alpha_loss.backward()
                    self.ent_coeff_optimizer.step()
                    self.alpha = self.log_ent_coeff.exp().item()

        if self.global_step % self.target_update_interval == 0:
            self._update_target_networks()

        all_metrics = rollout_metrics.copy()
        all_metrics.update({
            'train/qf1_values': qf1_a_values.mean().item(),
            'train/qf2_values': qf2_a_values.mean().item(),
            'train/qf1_loss': qf1_loss.item(),
            'train/qf2_loss': qf2_loss.item(),
            'train/critic_loss': qf_loss.item() / 2.0,
            'train/actor_loss': self.last_actor_loss,
            'train/alpha': self.alpha,
        })

        if alpha_loss is not None:
            all_metrics['train/alpha_loss'] = alpha_loss.item()

        if self.pac_bayes_active:
            # Infrequent PAC-Bayes updates
            update_pb = (self.global_step > 0 and self.global_step % self.pb_update_freq == 0 and self.global_step != self.learning_starts)
            if update_pb:
                current_time = time.time()
                print(f"\n--- Performing PAC-Bayes updates at step {self.global_step} ---")

                # 0. Synchronize posterior mean with current actor before optimization
                with torch.no_grad():
                    self.posterior_mean.data.copy_(self._flatten_policy_params())
                    print("Synchronized posterior mean with current actor parameters")

                # 1. Collect fresh data with the frozen policy
                replay_pb_update, replay_pb_test = self._frozen_policy_rollout()

                mixing_time = self.estimate_mixing_time(self.replay_pb)
                self.mixing_time = max(self.mixing_time, mixing_time)

                if len(replay_pb_update) > 0 and len(replay_pb_test) > 0:
                    # 2. Fit the posterior distribution using the training data
                    pb_info = self.update_pac_bayes_components(replay_pb_update)
                    for k, v in pb_info.items():
                        all_metrics[f"pac_bayes/update/{k}"] = v

                    # 3. Compute and log the certified bound using held-out test data
                    bound_info = self.compute_pac_bayes_bound(replay_pb_test)
                    for k, v in bound_info.items():
                        all_metrics[f"pac_bayes/{k}"] = v

                    # 4. Inject knowledge
                    self.inject_posterior_knowledge()

                    pb_spent_time = time.time() - current_time
                    print(f"--- PAC-Bayes updates took {pb_spent_time:.2f} seconds ---\n")

                    # Clear data so the next bound is computed on independent samples
                    self.replay_pb = []

            # Keep the posterior centred on the current actor after the policy update
            if hasattr(self, 'posterior_mean'):
                with torch.no_grad():
                    self.posterior_mean.data.copy_(self._flatten_policy_params())
                    self.posterior = DiagonalGaussian(self.posterior_mean, self.posterior_log_std)

            # Reset the prior every pb_reset_prior_freq steps (== 0 was missing,
            # which reset it on almost every step instead).
            if self.global_step % self.pb_reset_prior_freq == 0 or self.global_step == self.learning_starts:
                self.reset_prior()

            all_metrics['pac_bayes/r_max'] = self.r_max

        return all_metrics
    
    def save(self, path: str) -> None:
        """Save algorithm state, including PAC-Bayes components, to ``path``.

        ``r_max`` is stored as a plain Python float. It is updated from
        ``np.max`` during rollouts and may therefore be a NumPy scalar.
        ``torch.load`` with ``weights_only=True`` (the default since PyTorch
        2.6) refuses to unpickle NumPy scalars.

        Args:
            path: Destination file for ``torch.save``.
        """
        checkpoint = {
            'policy_state_dict': self.policy.state_dict(),
            'target_critic1_state_dict': self.target_critic1.state_dict(),
            'target_critic2_state_dict': self.target_critic2.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
            'global_step': self.global_step,
            'episode_count': self.episode_count,
            'algo_config': self.algo_config,
            'r_max': float(self.r_max),
            'mixing_time': self.mixing_time,
            'last_pac_bayes_bound': self.last_pac_bayes_bound
        }

        # Entropy-coefficient state exists only when it is learned
        if self.ent_coeff_optimizer is not None:
            checkpoint['ent_coeff_optimizer_state_dict'] = self.ent_coeff_optimizer.state_dict()
            checkpoint['log_ent_coeff'] = self.log_ent_coeff

        if self.pac_bayes_active:
            checkpoint.update({
                'posterior_mean': self.posterior_mean.data,
                'posterior_log_std': self.posterior_log_std.data,
                'prior_mean': self.prior_mean,
                'prior_log_std': self.prior_log_std,
                'pb_kappa': self.pb_kappa.data,
                'posterior_optimizer_state_dict': self.posterior_optimizer.state_dict(),
                'use_alternating_optimization': self.use_alternating_optimization
            })

            if self.use_alternating_optimization and hasattr(self, 'pb_kappa_optimizer'):
                checkpoint['pb_kappa_optimizer_state_dict'] = self.pb_kappa_optimizer.state_dict()

        torch.save(checkpoint, path)
    
    def load(self, path: str) -> None:
        """Restore algorithm state from a checkpoint at ``path``.

        Base SAC state is restored by ``super().load(path)``, which reads the
        file a second time. PAC-Bayes tensors, distributions and optimizer
        states are restored only when PAC-Bayes is active and the checkpoint
        contains ``posterior_mean``. ``r_max``, ``mixing_time`` and
        ``last_pac_bayes_bound`` are always set, falling back to
        ``r_max_estimate``, ``1`` and ``0.0`` when the keys are absent.

        Args:
            path: Checkpoint file previously written by ``save``.
        """
        ckpt = torch.load(path, map_location=self.device)

        super().load(path)

        if self.pac_bayes_active and 'posterior_mean' in ckpt:
            # Copy in place so existing references (e.g. optimizer params) stay valid
            for tensor_name in ('posterior_mean', 'posterior_log_std', 'prior_mean', 'prior_log_std'):
                getattr(self, tensor_name).data.copy_(ckpt[tensor_name])

            if 'pb_kappa' in ckpt:
                self.pb_kappa.data.copy_(ckpt['pb_kappa'])

            # Rebuild the distribution objects around the restored tensors
            self.posterior = DiagonalGaussian(self.posterior_mean, self.posterior_log_std)
            self.prior = DiagonalGaussian(self.prior_mean, self.prior_log_std)

            if 'posterior_optimizer_state_dict' in ckpt:
                self.posterior_optimizer.load_state_dict(ckpt['posterior_optimizer_state_dict'])

            restore_kappa_optimizer = (
                self.use_alternating_optimization
                and hasattr(self, 'pb_kappa_optimizer')
                and 'pb_kappa_optimizer_state_dict' in ckpt
            )
            if restore_kappa_optimizer:
                self.pb_kappa_optimizer.load_state_dict(ckpt['pb_kappa_optimizer_state_dict'])

        self.r_max = ckpt.get('r_max', self.r_max_estimate)
        self.mixing_time = ckpt.get('mixing_time', 1)
        self.last_pac_bayes_bound = ckpt.get('last_pac_bayes_bound', 0.0)