import numpy as np
import scipy.signal
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

import time
import os.path as osp

from copy import deepcopy
import itertools
import numpy as np
import torch
from torch.optim import Adam
import time

from logx import EpochLogger, setup_logger_kwargs

from models.common import *
from models.mlp import *
from models.mp import *

from memory import *
from utils import *
import math
from envs import *
from warnings import filterwarnings
from collections import deque

import logging
# Ignore DeprecationWarnings whose message starts with the `np.bool8` alias
# notice, so they do not clutter the training output.
filterwarnings(
    action="ignore",
    category=DeprecationWarning,
    message="`np.bool8` is a deprecated alias for `np.bool_`",
)

import os
from dotenv import load_dotenv
# Must run before `device` is computed below, because that line reads the
# "device" environment variable (presumably supplied through a .env file).
load_dotenv()

# Compute device used by every agent in this module: the "device" environment
# variable wins when it is set and non-empty; otherwise CUDA if available,
# else CPU.
device = os.getenv("device") if os.getenv("device") else "cuda" if torch.cuda.is_available() else "cpu" 
# NOTE(review): `logger` is used throughout this module but its definition is
# commented out here -- presumably it is provided by one of the star imports
# above; confirm, otherwise every `logger.debug(...)` raises NameError.
# logger = logging.getLogger(__name__)

def symmetrize_axes(axes):
    """Make both axis ranges symmetric around zero, with one unit of padding.

    For each axis the new half-width is the largest absolute value of the
    current limits plus 1. The y axis is adjusted first, then the x axis.

    Args:
        axes: Object exposing ``get_ylim``/``set_ylim`` and
            ``get_xlim``/``set_xlim`` (e.g. a matplotlib ``Axes``).
    """
    y_bound = np.max(np.abs(axes.get_ylim())) + 1
    axes.set_ylim(ymin=-y_bound, ymax=y_bound)

    x_bound = np.max(np.abs(axes.get_xlim())) + 1
    axes.set_xlim(xmin=-x_bound, xmax=x_bound)

class Agent:
    """Off-policy actor-critic agent with twin Q-networks and an entropy-regularized policy.

    The agent owns a training and a test environment (both built from
    ``env_fn``), an actor-critic module with a frozen, polyak-averaged target
    copy, a replay buffer, and one Adam optimizer each for the policy and for
    the two Q-networks.

    Subclasses must implement :meth:`reward_function`, which maps the
    environment reward to the reward that is stored and learned from.
    """

    def __init__(self, exp_name, env_fn, actor_critic=MLPActorCritic, ac_kwargs=dict(), seed=0, 
        steps_per_epoch=4000, epochs=100, replay_size=int(1e6), gamma=0.99, 
        polyak=0.995, lr=1e-3, alpha=0.2, 
        batch_size=256, start_steps=10000, 
        update_after=1000, update_every=16, num_test_episodes=3, max_ep_len=1000, 
        logger_kwargs=dict(), save_freq=1, test_mode=False, early_stopping=False) -> None:
        """Build environments, networks, replay buffer, optimizers and logger.

        Args:
            exp_name: Experiment name; used in log rows and output paths.
            env_fn: Zero-argument callable returning an environment. It is
                called twice (training env and test env).
            actor_critic: Callable invoked as
                ``actor_critic(obs_dim, act_dim, act_limit, **ac_kwargs)``.
            ac_kwargs: Extra keyword arguments for ``actor_critic``. The
                shared default dict is never mutated here.
            seed: Seed applied to torch and numpy.
            steps_per_epoch: Environment steps per epoch.
            epochs: Number of epochs; training runs
                ``steps_per_epoch * epochs`` steps.
            replay_size: Capacity passed to ``ReplayBuffer``.
            gamma: Discount factor.
            polyak: Averaging coefficient for the target networks.
            lr: Learning rate for both the policy and the Q optimizer.
            alpha: Entropy coefficient.
            batch_size: Replay batch size per gradient step.
            start_steps: Actions are sampled uniformly from the action space
                while ``t <= start_steps``.
            update_after: First step at which gradient updates may run.
            update_every: Updates run every this many steps, with this many
                gradient steps each time.
            num_test_episodes: Episodes collected by :meth:`test_agent`.
            max_ep_len: Episode length at which the episode is cut.
            logger_kwargs: Keyword arguments for ``EpochLogger``.
            save_freq: Save the model every this many epochs.
            test_mode: When true, no logger is created and nothing is saved
                during construction.
            early_stopping: When true, training stops once the mean length of
                the most recent episodes reaches 90% of ``max_ep_len``.
        """
        logger.debug(f"{RED}Building agent{ENDC}") 
        
        self.exp_name = exp_name
        self.gamma = gamma
        self.alpha = alpha
        self.env_fn = env_fn
        self.steps_per_epoch=steps_per_epoch
        self.epochs=epochs
        self.polyak=polyak
        self.batch_size=batch_size
        self.start_steps=start_steps 
        self.update_after=update_after
        self.update_every=update_every
        self.num_test_episodes=num_test_episodes
        self.max_ep_len=max_ep_len
        self.save_freq = save_freq
        self.seed = seed
        self.early_stopping = early_stopping
        self.test_mode = test_mode
        
        if not test_mode: 
            self.logger = EpochLogger(**logger_kwargs)
        # Snapshot of the constructor arguments for the saved config. Copy the
        # mapping so that dropping "self" never touches the frame's own dict.
        configs = dict(locals())
        configs.pop("self", None)

        torch.manual_seed(seed)
        np.random.seed(seed)

        self.env, self.test_env = self.env_fn(), self.env_fn()
        logger.debug(f"{GREEN} observation space: {self.env.observation_space}{ENDC}")
        logger.debug(f"{GREEN} action space: {self.env.action_space}{ENDC}")
        obs_dim, act_dim, act_limit = get_env_dims(env=self.env)
        
        # Create actor-critic module and target networks
        logger.debug(f"{RED}actor_critic: {actor_critic}{ENDC}")
        logger.debug(f"{RED}ac_kwargs: {ac_kwargs}{ENDC}")
        print(ac_kwargs)
        self.ac = actor_critic(obs_dim, act_dim, act_limit, **ac_kwargs).to(device)
        self.ac_targ = deepcopy(self.ac).to(device)
        # Freeze target networks with respect to optimizers (only update via polyak averaging)
        for p in self.ac_targ.parameters():
            p.requires_grad = False
            
        # Parameters of both Q-networks. This must be a list, not a bare
        # itertools.chain: the optimizer consumes the iterable it is given,
        # which would leave nothing for the freeze/unfreeze loops in update().
        self.q_params = list(itertools.chain(self.ac.q1.parameters(), self.ac.q2.parameters()))
        
        # Experience buffer
        self.replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)
        
        # Set up optimizers for policy and q-function
        logger.debug(f"{RED}lr: {lr}{ENDC}")
        self.pi_optimizer = Adam(self.ac.pi.parameters(), lr=lr)
        self.q_optimizer = Adam(self.q_params, lr=lr)
        
        # Set up model saving
        if not test_mode:
            # Count variables (protip: try to get a feel for how different size networks behave!)
            self.var_counts = tuple(count_vars(module) for module in [self.ac.pi, self.ac.q1, self.ac.q2])
            self.logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d \n' %self.var_counts)
            self.logger.save_config(configs)
            self.logger.setup_pytorch_saver(self.ac)
        
        # Metrics written by _log_training; subclasses append their own.
        self.metrics = [
            'EpRet',
            'EpLen',
            'EpLogPiRet',
            'Q1Vals',
            'Q2Vals',
            'LogPi',
            'STD',
            'LossPi',
            'LossQ',
        ]
        print(self.ac.pi)
    
    def update_q(self, data):
        """Run one gradient step on both Q-networks.

        The regression target is the entropy-regularized Bellman backup
        ``r + gamma * (1 - d) * (min(Q1', Q2')(o2, a2) - alpha * log pi(a2|o2))``
        where ``a2`` is sampled from the current policy and ``Q'`` are the
        target networks.

        Args:
            data: Batch mapping with keys 'obs', 'act', 'rew', 'obs2', 'done'.
        """
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        q1 = self.ac.q1(o,a)
        q2 = self.ac.q2(o,a)
            
        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2, *_ = self.ac.pi(o2)
            # Target Q-values
            q1_pi_targ = self.ac_targ.q1(o2, a2)
            q2_pi_targ = self.ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + self.gamma * (1 - d) * (q_pi_targ - self.alpha * logp_a2)
        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.cpu().detach().numpy(), Q2Vals=q2.cpu().detach().numpy())

        self.q_optimizer.zero_grad()
        loss_q.backward()
        self.q_optimizer.step()
        
        # Record things
        self.logger.store(LossQ=loss_q.item(), **q_info)

    def update_pi(self, data):
        """Run one gradient step on the policy.

        Minimizes ``mean(alpha * log pi(a|o) - min(Q1, Q2)(o, a))`` for
        actions ``a`` sampled from the current policy.

        Args:
            data: Batch mapping; only the 'obs' entry is read.
        """
        o= data['obs']
        a, logp_a, pi, mu, std, *_ = self.ac.pi(o)
        
        q1_pi = self.ac.q1(o, a)
        q2_pi = self.ac.q2(o, a)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (self.alpha * logp_a  - q_pi).mean()

        # Useful info for logging
        pi_info = dict(LogPi=logp_a.cpu().detach().numpy(), STD=std.mean().item())

        self.pi_optimizer.zero_grad()
        loss_pi.backward()
        self.pi_optimizer.step()
        
        # Record things
        self.logger.store(LossPi=loss_pi.item(), **pi_info)
        
    def update(self, data):
        """Perform one full update: Q-networks, then policy, then targets.

        Args:
            data: Replay batch as produced by ``ReplayBuffer.sample_batch``.
        """
        # First run one gradient descent step for Q1 and Q2 
        self.update_q(data)

        # Freeze the Q-networks so the policy step does not spend effort
        # computing gradients for them.
        for p in self.q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        self.update_pi(data)

        # Unfreeze the Q-networks so they can be optimized at the next update.
        for p in self.q_params:
            p.requires_grad = True

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                p_targ.data.mul_(self.polyak)
                p_targ.data.add_((1 - self.polyak) * p.data)
            
    def get_action(self, o, deterministic=False):
        """Return the actor-critic's action for observation ``o``."""
        return self.ac.act(torch.as_tensor(o, dtype=torch.float32).to(device), deterministic)

    def test_agent(self):
        """Collect test episodes (with video) through ``data_collector``.

        Episodes are written below ``test_videos/epoch_<t>`` inside the
        logger's output directory, or below "data" when the logger exposes no
        output directory. If no saved model exists yet at the expected path,
        the current actor-critic is saved there first.
        """
        from data_collector import CollectorConfig, collect_data

        logger.debug(f"{CYAN}Testing agent{ENDC}") 
        # Test videos directory, relative to the collector's base directory
        test_videos_dir = f"test_videos/epoch_{self.t}"
        
        # Use at most 3 columns for the grid video
        grid_cols = min(3, self.num_test_episodes)

        output_dir = getattr(self.logger, 'output_dir', None)
        model_path = os.path.join(output_dir, "pyt_save/model.pt") if output_dir is not None else None
        
        # Configure data collection for test episodes
        config = CollectorConfig(
            max_ep_length=self.max_ep_len * 2,
            save_video=True,
            video_fps=30,
            base_dir=output_dir if output_dir is not None else "data",
            collect_activation_layers=False,
            num_workers=1,  # Sequential processing for test episodes
            device=device,
            env_id=self.test_env.unwrapped.spec.id,
            test_mode=True,
            model_path=model_path,
            bias_config=None,
            random_component=True,
            # NOTE(review): only the mixture agents define n_components; the
            # base agent passes None here -- confirm CollectorConfig accepts it.
            n_components=getattr(self, "n_components", None),
            # Grid video settings
            video_grid_cols=grid_cols,
            video_frame_width=500,  # Reasonable size for grid layout
            video_frame_height=500  # Reasonable size for grid layout
        )
        
        # If we don't have a saved model yet, save the current model
        if config.model_path is not None and not os.path.exists(config.model_path):
            os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
            torch.save(self.ac, config.model_path)
        
        # Collect test episodes data and videos
        total_steps = collect_data(
            config=config,
            dir_name=test_videos_dir,
            num_episodes=self.num_test_episodes, 
            start_from=0
        )
        
        # Log the path to the grid video for easy reference
        grid_video_path = os.path.join(config.base_dir, test_videos_dir, "videos", f"{self.t}.mp4")
        if os.path.exists(grid_video_path):
            self.logger.log(f"Test grid video saved to: {grid_video_path}")
        
        self.logger.log(f"Collected {self.num_test_episodes} test episodes over {total_steps} total steps")

    def reward_function(self, reward):
        """Map the environment reward to the training reward (subclass hook)."""
        raise NotImplementedError
     
    def _log_plot(self, obs_list, infos, action_infos, t):
        """Save a 2-D trajectory plot of one episode as ``<t>_map.png``.

        Consecutive 'ant_*_position' entries of ``infos`` are drawn in blue
        and 'point_*_position' entries in red; missing keys default to 0.

        NOTE(review): ``plt`` is not imported in the visible part of this
        module -- presumably it comes from a star import; confirm.
        """
        pics_dir = f"data/{self.exp_name}/{self.exp_name}_s{self.seed}/pics_training"
        os.makedirs(pics_dir, exist_ok=True) 
        
        plt.figure() 
        for i in range(len(obs_list)-1):
            x_position, y_position = infos[i].get('ant_x_position', 0), infos[i].get('ant_y_position', 0)
            x2_position, y2_position = infos[i+1].get('ant_x_position', 0), infos[i+1].get('ant_y_position', 0)
            plt.plot((x_position, x2_position), (y_position, y2_position), "b")
        
        for i in range(len(obs_list)-1):
            x_position, y_position = infos[i].get('point_x_position', 0), infos[i].get('point_y_position', 0)
            x2_position, y2_position = infos[i+1].get('point_x_position', 0), infos[i+1].get('point_y_position', 0)
            plt.plot((x_position, x2_position), (y_position, y2_position), "r")
             
        ax = plt.gca()
        symmetrize_axes(ax)
        plt.savefig(f"{pics_dir}/{t}_map.png")
        plt.close()
    
    def _save_ep_data(self, obs_list, infos, action_infos, t):
        """Per-episode dump hook; currently a no-op (plot/CSV output disabled)."""
        # self._log_plot(obs_list, infos, action_infos, t)
        
        # info_dir = f"data/{self.exp_name}/{self.exp_name}_s{self.seed}/action_info_training"
        # os.makedirs(info_dir, exist_ok=True) 

        # if action_infos != []:
        #     df_action_infos = pd.DataFrame(action_infos)
        #     df_action_infos.to_csv(f"{info_dir}/{t}_action_infos.csv", index=False)
        pass
     
    def train(self):
        """Run the main interaction/update loop.

        Each step: choose an action (uniformly random while
        ``t <= start_steps``, otherwise sampled from the policy), step the
        environment, store the transition, and periodically run gradient
        updates, save the model, and log the epoch. Returns early when early
        stopping is enabled and triggers.
        """
        # Prepare for interaction with environment
        total_steps = self.steps_per_epoch * self.epochs
        start_time = time.time()
        o, ep_log_a_ret, ep_ret, ep_len = self.env.reset()[0], 0, 0, 0
        obs_list, infos, action_infos = [], [], []
        # For early stopping: sliding window over the most recent episode lengths
        average_period = 20
        recent_ep_lens = deque(maxlen=average_period)
        early_stop_threshold = 0.90 * self.max_ep_len
         
        # Main loop: collect experience in env and update/log each epoch
        self.t = 0
        while self.t < total_steps:
            obs_list.append(o)            
            if self.t > self.start_steps:
                action_info = self.ac.act_extended(torch.as_tensor(np.expand_dims(o, axis=0), dtype=torch.float32).to(device), deterministic=False)
                ep_log_a_ret += action_info['logp_a'][0]
                a = action_info['a'][0]
            else:
                a = self.env.action_space.sample()

            # Step the env
            o2, r, d, truncated, info = self.env.step(a)
            infos.append(info)
            r = self.reward_function(r)
            ep_ret += r
            ep_len += 1

            # Hitting the length cap is not a true terminal state, so it must
            # not cut off bootstrapping in the Bellman backup.
            d = False if ep_len==self.max_ep_len else d

            # Store experience to replay buffer
            self.replay_buffer.store(o, a, r, o2, d)
            
            o = o2

            # End of trajectory handling. An env-side truncation also ends the
            # episode: the env must be reset before it is stepped again.
            if d or truncated or (ep_len == self.max_ep_len):
                self.logger.store(EpRet=ep_ret, EpLen=ep_len, EpLogPiRet=ep_log_a_ret)
                
                # Early stopping tracking
                if self.early_stopping:
                    recent_ep_lens.append(ep_len)  # deque keeps only the newest `average_period`
                
                    # Check the condition once the window is full
                    if len(recent_ep_lens) == average_period:
                        avg_ep_len = sum(recent_ep_lens) / average_period
                        if avg_ep_len >= early_stop_threshold:
                            self.logger.log(f"\nEarly stopping triggered! Average episode length ({avg_ep_len:.2f}) reached {early_stop_threshold:.2f} threshold.")
                            # Save final model
                            self.logger.save_state({'env': self.env}, None)
                            # Log final stats
                            self._log_training(self.t // self.steps_per_epoch + 1, self.t, start_time)
                            return  # Exit training
                
                # Reset every per-episode accumulator, including the log-prob return.
                o, ep_log_a_ret, ep_ret, ep_len = self.env.reset()[0], 0, 0, 0
                self._save_ep_data(obs_list, infos, action_infos, self.t)
                obs_list = []
                infos = []
                action_infos = []

            # Update handling
            if self.t >= self.update_after and self.t % self.update_every == 0:
                for _ in range(self.update_every):
                    batch = self.replay_buffer.sample_batch(self.batch_size)
                    self.update(data=batch)

            # End of epoch handling
            if (self.t+1) % self.steps_per_epoch == 0:
                epoch = (self.t+1) // self.steps_per_epoch

                # Save model
                if (epoch % self.save_freq == 0) or (epoch == self.epochs):
                    self.logger.save_state({'env': self.env}, None)

                # Test the performance of the deterministic version of the agent.
                # self.test_agent()

                # Skip logging until updates have started. This is a plain
                # condition rather than `continue` so the step counter below
                # still advances.
                if self.t >= self.update_after:
                    self._log_training(epoch, self.t, start_time)
            self.t += 1
    
    def _log_training(self, epoch, t, start_time):
        """Write one tabular log row with bookkeeping fields and all metrics."""
        self.logger.log_tabular('Experiment Name', self.exp_name + "_" +str(self.seed))
        self.logger.log_tabular('Epoch', epoch)
        self.logger.log_tabular('TotalEnvInteracts', t)
        self.logger.log_tabular('Time', time.time()-start_time)
        for metric in self.metrics:
            self.logger.log_tabular(metric, average_only=True)
        self.logger.dump_tabular()

class MOP(Agent):
    """Agent variant that discards the environment reward entirely."""

    def __init__(self, exp_name, env_fn, actor_critic=MLPActorCritic, ac_kwargs=dict(), **kwargs) -> None:
        """Forward all arguments unchanged to :class:`Agent`."""
        super().__init__(
            exp_name=exp_name,
            env_fn=env_fn,
            actor_critic=actor_critic,
            ac_kwargs=ac_kwargs,
            **kwargs,
        )

    def reward_function(self, reward):
        """Return 0 regardless of ``reward``."""
        return 0

class EGready(MOP):
    """Epsilon-greedy variant of :class:`MOP`.

    With probability ``epsilon`` :meth:`get_action` samples uniformly from the
    action space; otherwise it returns the deterministic policy action.
    """

    def __init__(self, exp_name, env_fn, actor_critic=MLPActorCritic, epsilon=2, ac_kwargs=dict(), **kwargs) -> None:
        """Build the agent and store the exploration rate.

        Args:
            exp_name: Experiment name, forwarded to :class:`MOP`.
            env_fn: Environment factory, forwarded to :class:`MOP`.
            actor_critic: Actor-critic constructor, forwarded to :class:`MOP`.
            epsilon: Probability of taking a uniformly random action.
                NOTE(review): ``random.random()`` is always below 1, so the
                default of 2 means every action is random -- confirm intent.
            ac_kwargs: Keyword arguments for ``actor_critic``.
            **kwargs: Remaining :class:`Agent` options.
        """
        # The base constructor already stores exp_name.
        super().__init__(exp_name, env_fn, actor_critic, ac_kwargs, **kwargs)
        self.epsilon = epsilon
            
    def get_action(self, o, deterministic=False, **kwargs):
        """Return an epsilon-greedy action for observation ``o``.

        ``deterministic`` is accepted, positionally or by keyword, so the
        method can be called exactly like ``Agent.get_action``. It is ignored:
        the non-exploring branch always uses the deterministic policy action.
        """
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        return self.ac.act(torch.as_tensor(o, dtype=torch.float32).to(device), deterministic=True)

class MixtureEntropyAgent(Agent):
    """
    Agent with a mixture policy that maximizes the mixture entropy while
    penalizing the entropy of the sampled component.

    Both the Q-target and the policy objective use the term
    ``Q(s, a) - log pi_mix(a|s) + alpha * log pi_k(a|s)``, where ``k`` is the
    sampled mixture component. Twin Q-networks with polyak-averaged targets
    are trained from the replay buffer, as in the base :class:`Agent`.
    """
    def __init__(
        self, 
        exp_name, 
        env_fn, 
        actor_critic=FastMixturePolicyActorCritic, 
        ac_kwargs=dict(), 
        num_action_samples=32,  # Stored on the instance; not read anywhere in this class
        lr=1e-3,  # Learning rate, forwarded to Agent for both optimizers
        **kwargs
    ) -> None:
        """Build the base agent and register the mixture-specific metrics.

        Args:
            exp_name: Experiment name, forwarded to :class:`Agent`.
            env_fn: Environment factory, forwarded to :class:`Agent`.
            actor_critic: Actor-critic constructor. Its policy must expose
                ``n_components`` and return, when called, the 7-tuple
                unpacked in :meth:`update_pi`.
            ac_kwargs: Keyword arguments for ``actor_critic``.
            num_action_samples: Stored as ``self.num_action_samples``.
            lr: Learning rate forwarded to :class:`Agent`.
            **kwargs: Remaining :class:`Agent` options.
        """
        super().__init__(exp_name, env_fn, actor_critic, ac_kwargs, lr=lr, **kwargs)
        self.num_action_samples = num_action_samples
        
        # Metrics specific to the mixture entropy objective
        self.metrics.extend([
            'ComponentIndices',
            'MixtureLogProb',
            'ComponentLogProb',
            'MixtureLogProbRaw',
            'ComponentLogProbRaw',
            'VarOfMeans',
        ])
        
        # Sliding window of recently sampled component ids (for logging)
        self.component_indices = deque(maxlen=1000)
        self.update_counter = 0 
        
        # Q parameters are kept as a list because update() iterates them on
        # every call; a bare itertools.chain would be exhausted by the first
        # loop.
        self.q_params = list(itertools.chain(self.ac.q1.parameters(), self.ac.q2.parameters()))
        self.n_components = self.ac.pi.n_components

    def update_q(self, data):
        """
        Run one gradient step on both Q-networks.

        For every transition in the batch a component and an action are
        sampled for the next state, and the regression target is
        ``y = r + gamma * (1 - d) * [Q'(s', a') - log pi_mix(a'|s') + alpha * log pi_k'(a'|s')]``.

        Args:
            data: Batch mapping with keys 'obs', 'act', 'rew', 'obs2', 'done'.
        """
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        # Q-values for the stored state-action pairs
        q1 = self.ac.q1(o, a)
        q2 = self.ac.q2(o, a)

        with torch.no_grad():
            # Next actions come from the *current* policy
            a_prime, logp_mixture_prime, logp_component_prime, _, _, _, _ = self.ac.pi(o2)
            
            # Clipped double-Q target
            q_targ = torch.min(self.ac_targ.q1(o2, a_prime), self.ac_targ.q2(o2, a_prime))
            
            target_value = q_targ - logp_mixture_prime + self.alpha * logp_component_prime
            
            # The reward term is included so the backup matches the formula
            # above; with this class's reward_function it is always 0.
            backup = r + self.gamma * (1 - d) * target_value

        # MSE loss against the Bellman backup for both Q-networks
        loss_q1 = ((q1 - backup) ** 2).mean()
        loss_q2 = ((q2 - backup) ** 2).mean()
        loss_q = loss_q1 + loss_q2

        # Gradient descent step for the Q functions
        self.q_optimizer.zero_grad()
        loss_q.backward()
        self.q_optimizer.step()

        # Record metrics for logging
        q_info = dict(Q1Vals=q1.cpu().detach().numpy(), Q2Vals=q2.cpu().detach().numpy())
        self.logger.store(LossQ=loss_q.item(), **q_info)

    def update_pi(self, data):
        """
        Run one gradient step on the mixture policy and log its statistics.

        Maximizes ``E[min(Q1, Q2)(s, a) - log pi_mix(a|s) + alpha * log pi_k(a|s)]``
        for freshly sampled components ``k`` and actions ``a``.

        Args:
            data: Batch mapping; only the 'obs' entry is read.
        """
        o = data['obs']
        
        # Sample fresh component indices and actions from the current policy
        a_tilde, logp_mixture, logp_component, pi, mu, selected_std, pi_info = self.ac.pi(o)
        
        # Track which components were sampled recently
        self.component_indices.extend(pi_info['indices'].detach().cpu().numpy())
        
        # Q-values for the sampled actions
        q_min = torch.min(self.ac.q1(o, a_tilde), self.ac.q2(o, a_tilde))
        
        objective = q_min - logp_mixture + self.alpha * logp_component
        loss_pi = -objective.mean()  # Negate for gradient ascent
        
        # Update policy parameters
        self.pi_optimizer.zero_grad()
        loss_pi.backward()
        self.pi_optimizer.step()
        
        # Variance of the component means, taken over dim 1 and then averaged
        # (assumes all_mus is [batch, n_components, action_dim] -- TODO confirm)
        all_mus = pi_info.get('all_mus')
        var_of_means = all_mus.var(dim=1).mean(0).mean().item()  
        
        # Mean mixing probability of each component over the batch
        # (assumes mixing_probs is [batch, n_components] -- TODO confirm)
        mean_component_probs = pi_info["mixing_probs"].mean(dim=0)
        
        motor_stats_dict = {'VarOfMeans': var_of_means}
        # NOTE(review): the ComponentProb<i> keys are stored but not listed in
        # self.metrics, so _log_training never reports them -- confirm the
        # logger does not accumulate them indefinitely.
        for i, prob in enumerate(mean_component_probs):
            motor_stats_dict[f'ComponentProb{i}'] = prob.item()
        
        # Metrics for monitoring
        pi_info_dict = dict(
            STD=selected_std.mean().item(),
            LogPi=logp_mixture.detach().cpu().numpy(),
            MixtureLogProb=logp_mixture.mean().item(),
            ComponentLogProb=logp_component.mean().item(),
            ComponentLogProbRaw = pi_info['logp_comp_raw'].mean().item(),
            MixtureLogProbRaw = pi_info['logp_raw'].mean().item(),
            # Number of distinct components seen in the sliding window
            ComponentIndices=len(set(self.component_indices))
        )
        
        self.logger.store(LossPi=loss_pi.item(), **pi_info_dict, **motor_stats_dict)
        
    def update(self, data):
        """
        Perform one full update:
        1. Update the Q-functions
        2. Update the policy (with the Q-networks frozen)
        3. Update the target networks by polyak averaging
        """
        self.update_q(data)
        
        # Freeze the Q-networks so the policy step computes no gradients for them
        for p in self.q_params:
            p.requires_grad = False
        
        self.update_pi(data)
        
        # Unfreeze for the next Q update
        for p in self.q_params:
            p.requires_grad = True
        
        with torch.no_grad():
            for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                p_targ.data.mul_(self.polyak)
                p_targ.data.add_((1 - self.polyak) * p.data)    
                    
        self.update_counter += 1
    
    def reward_function(self, reward):
        """Return 0: external rewards are ignored (pure entropy maximization)."""
        return 0

class MixtureEntropyAgentLimited(Agent):
    """
    Mixture-policy agent that maximizes the mixture entropy under a
    per-component entropy cap enforced with dual variables.

    Both the Q-target and the policy objective use the term
    ``Q(s, a) - log pi_mix(a|s) + lambda_k * log pi_k(a|s)``, where ``k`` is
    the sampled component and ``lambda_k >= 0`` is that component's dual
    variable, adapted by :meth:`update_lmbda`. Twin Q-networks with
    polyak-averaged targets are trained from the replay buffer.
    """
    def __init__(
        self, 
        exp_name, 
        env_fn, 
        actor_critic=FastMixturePolicyActorCritic, 
        ac_kwargs=dict(), 
        num_action_samples=32,  # Stored on the instance; not read anywhere in this class
        lr=1e-3,  # Learning rate, forwarded to Agent for both optimizers
        **kwargs
    ) -> None:
        """Build the base agent, the extra metrics and the dual variables.

        Args:
            exp_name: Experiment name, forwarded to :class:`Agent`.
            env_fn: Environment factory, forwarded to :class:`Agent`.
            actor_critic: Actor-critic constructor. Its policy must expose
                ``n_components`` and return, when called, the 7-tuple
                unpacked in :meth:`update_pi`.
            ac_kwargs: Keyword arguments for ``actor_critic``.
            num_action_samples: Stored as ``self.num_action_samples``.
            lr: Learning rate forwarded to :class:`Agent`.
            **kwargs: Remaining :class:`Agent` options, plus two optional
                keys consumed here: ``entropy_cap`` (cap C, default -1000)
                and ``dual_lr`` (dual step size, default 1e-2).
        """
        # Take the dual-update options out of kwargs before forwarding:
        # Agent.__init__ does not accept them and would raise TypeError.
        entropy_cap = kwargs.pop('entropy_cap', -1000)
        dual_lr = kwargs.pop('dual_lr', 1e-2)

        super().__init__(exp_name, env_fn, actor_critic, ac_kwargs, lr=lr, **kwargs)
        self.num_action_samples = num_action_samples
        
        # Metrics specific to the mixture entropy objective and the dual variables
        self.metrics.extend([
            'ComponentIndices',
            'MixtureLogProb',
            'ComponentLogProb',
            'MixtureLogProbRaw',
            'ComponentLogProbRaw',
            'VarOfMeans',
            'MixVar',
            'MinLambda',
            'MaxLambda',
            'MeanLambda',
        ])
         
        # Sliding window of recently sampled component ids (for logging)
        self.component_indices = deque(maxlen=1000)
        self.update_counter = 0 
        
        # Q parameters are kept as a list because update() iterates them on
        # every call; a bare itertools.chain would be exhausted by the first
        # loop.
        self.q_params = list(itertools.chain(self.ac.q1.parameters(), self.ac.q2.parameters()))
        self.n_components = self.ac.pi.n_components

        # --- Dual variables for the per-component entropy cap ---
        self.C = entropy_cap
        self.dual_lr = dual_lr
        self.lmbda = torch.zeros(self.n_components, device=device)  # one lambda_i >= 0 per component

    def update_lmbda(self, indices: torch.Tensor, logp_component: torch.Tensor):
        """
        Dual update for the per-component entropy cap:
            lambda_i <- max(0, lambda_i + dual_lr * (mean(-log pi_i) - C))

        Only components that occur in ``indices`` are updated; the mean is
        taken over the samples assigned to that component.

        Args:
            indices: Sampled component id for each sample in the batch.
            logp_component: ``log pi_z(a|s)`` of the sampled component for
                each sample in the batch.
        """
        ent_terms = (-logp_component).detach()  # per-sample entropy estimate
        with torch.no_grad():
            for i in range(self.n_components):
                mask = (indices == i)
                if mask.any():
                    H_i_hat = ent_terms[mask].mean()
                    self.lmbda[i] = torch.clamp(
                        self.lmbda[i] + self.dual_lr * (H_i_hat - self.C),
                        min=0.0
                    )
        self.logger.store(MinLambda = self.lmbda.min().item(), MaxLambda=self.lmbda.max().item(), MeanLambda=self.lmbda.mean().item())

    def update_q(self, data):
        """
        Run one gradient step on both Q-networks.

        For every transition in the batch a component and an action are
        sampled for the next state, and the regression target is
        ``y = r + gamma * (1 - d) * [Q'(s', a') - log pi_mix(a'|s') + lambda_k' * log pi_k'(a'|s')]``.

        Args:
            data: Batch mapping with keys 'obs', 'act', 'rew', 'obs2', 'done'.
        """
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        # Q-values for the stored state-action pairs
        q1 = self.ac.q1(o, a)
        q2 = self.ac.q2(o, a)

        with torch.no_grad():
            # Next actions come from the *current* policy
            a_prime, logp_mixture_prime, logp_component_prime, _, _, _, pi_info_prime = self.ac.pi(o2)
            
            # Clipped double-Q target
            q_targ = torch.min(self.ac_targ.q1(o2, a_prime), self.ac_targ.q2(o2, a_prime))
            
            # Dual variable of the component sampled for each next state
            lam_prime = self.lmbda[pi_info_prime['indices']]
            target_value = q_targ - logp_mixture_prime + lam_prime * logp_component_prime
            
            # The reward term is included so the backup matches the formula
            # above; with this class's reward_function it is always 0.
            backup = r + self.gamma * (1 - d) * target_value

        # MSE loss against the Bellman backup for both Q-networks
        loss_q1 = ((q1 - backup) ** 2).mean()
        loss_q2 = ((q2 - backup) ** 2).mean()
        loss_q = loss_q1 + loss_q2

        # Gradient descent step for the Q functions
        self.q_optimizer.zero_grad()
        loss_q.backward()
        self.q_optimizer.step()

        # Record metrics for logging
        q_info = dict(Q1Vals=q1.cpu().detach().numpy(), Q2Vals=q2.cpu().detach().numpy())
        self.logger.store(LossQ=loss_q.item(), **q_info)

    def update_pi(self, data):
        """
        Run one policy gradient step, then the dual update, then log statistics.

        Maximizes ``E[min(Q1, Q2)(s, a) - log pi_mix(a|s) + lambda_k * log pi_k(a|s)]``
        for freshly sampled components ``k`` and actions ``a``.

        Args:
            data: Batch mapping; only the 'obs' entry is read.
        """
        o = data['obs']
        
        # Sample fresh component indices and actions from the current policy
        a_tilde, logp_mixture, logp_component, pi, mu, selected_std, pi_info = self.ac.pi(o)
        
        # Track which components were sampled recently
        self.component_indices.extend(pi_info['indices'].detach().cpu().numpy())
        
        # Q-values for the sampled actions
        q_min = torch.min(self.ac.q1(o, a_tilde), self.ac.q2(o, a_tilde))
        
        # Dual variable of the component sampled for each state
        lam = self.lmbda[pi_info['indices']]
        objective = q_min - logp_mixture + lam * logp_component
        loss_pi = -objective.mean()  # Negate for gradient ascent
        
        # Update policy parameters
        self.pi_optimizer.zero_grad()
        loss_pi.backward()
        self.pi_optimizer.step()
        
        # Dual update, using the log-probs computed before the policy step
        self.update_lmbda(pi_info['indices'], logp_component)
        
        # Variance of the component means, taken over dim 1 and then averaged
        # (assumes all_mus is [batch, n_components, action_dim] -- TODO confirm)
        all_mus = pi_info.get('all_mus')
        var_of_means = all_mus.var(dim=1).mean(0).mean().item()  
        
        # Mean mixing probability of each component over the batch
        # (assumes mixing_probs is [batch, n_components] -- TODO confirm)
        mean_component_probs = pi_info["mixing_probs"].mean(dim=0)
        
        motor_stats_dict = {'VarOfMeans': var_of_means}
        # NOTE(review): the ComponentProb<i> keys are stored but not listed in
        # self.metrics, so _log_training never reports them -- confirm the
        # logger does not accumulate them indefinitely.
        for i, prob in enumerate(mean_component_probs):
            motor_stats_dict[f'ComponentProb{i}'] = prob.item()
        
        # Metrics for monitoring
        pi_info_dict = dict(
            STD=selected_std.mean().item(),
            LogPi=logp_mixture.detach().cpu().numpy(),
            MixVar=pi_info['mix_var'].mean().item(),
            MixtureLogProb=logp_mixture.mean().item(),
            ComponentLogProb=logp_component.mean().item(),
            ComponentLogProbRaw = pi_info['logp_comp_raw'].mean().item(),
            MixtureLogProbRaw = pi_info['logp_raw'].mean().item(),
            # Number of distinct components seen in the sliding window
            ComponentIndices=len(set(self.component_indices))
        )
        
        self.logger.store(LossPi=loss_pi.item(), **pi_info_dict, **motor_stats_dict)
        
    def update(self, data):
        """
        Perform one full update:
        1. Update the Q-functions
        2. Update the policy and dual variables (with the Q-networks frozen)
        3. Update the target networks by polyak averaging
        """
        self.update_q(data)
        
        # Freeze the Q-networks so the policy step computes no gradients for them
        for p in self.q_params:
            p.requires_grad = False
        
        self.update_pi(data)
        
        # Unfreeze for the next Q update
        for p in self.q_params:
            p.requires_grad = True
        
        with torch.no_grad():
            for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                p_targ.data.mul_(self.polyak)
                p_targ.data.add_((1 - self.polyak) * p.data)    
                    
        self.update_counter += 1
    
    def reward_function(self, reward):
        """Return 0: external rewards are ignored (pure entropy maximization)."""
        return 0

class MixtureEntropyMOP(MixtureEntropyAgent):
    """
    Maximum entropy mixture policy with the maximum entropy objective.
    MOP variant that ignores environment rewards.

    Differs from the parent only in its default ``actor_critic``
    (``FastMixturePolicyActorCritic``) and in a ``reward_function`` that
    always returns 0.
    """
    def __init__(self, exp_name, env_fn, actor_critic=FastMixturePolicyActorCritic, ac_kwargs=None, **kwargs) -> None:
        """
        Forward all arguments to the parent constructor.

        Args:
            exp_name: Experiment name, passed through to the parent.
            env_fn: Environment factory, passed through to the parent.
            actor_critic: Actor-critic class to build.
            ac_kwargs: Keyword arguments for ``actor_critic``. ``None``
                (the default) means an empty dict.
            **kwargs: Remaining keyword arguments for the parent constructor.
        """
        # Build a fresh dict per call: a mutable default would be one object
        # shared by every instance created without an explicit ac_kwargs.
        if ac_kwargs is None:
            ac_kwargs = {}
        super().__init__(exp_name, env_fn, actor_critic, ac_kwargs, **kwargs)

    def reward_function(self, reward):
        """Return 0 regardless of ``reward`` (pure entropy maximization)."""
        return 0

def build_agent(args):
    """
    Build the agent selected by ``args.ac_model``.

    The experiment name is derived from ``args.exp_name`` with the alpha,
    gamma, number of components and episode length appended. The
    environment factory is created through ``build_env`` with
    ``render_mode="rgb_array"``.

    Args:
        args: Run configuration object. Attributes read here: ``exp_name``,
            ``alpha``, ``gamma``, ``n_components``, ``eplen``, ``seed``,
            ``ac_model``, ``epochs``, ``bs``, ``steps_per_epoch``,
            ``start_steps``, ``lr``, ``test``, ``early_stopping``; for the
            ``"mpf"`` model also ``layers`` and ``dropout_rate``. ``args`` is
            additionally passed whole to ``build_env``.

    Returns:
        The constructed agent instance (``MOP`` for ``"mlp"``,
        ``MixtureEntropyMOP`` for ``"mpf"``).

    Raises:
        ValueError: If ``args.ac_model`` is neither ``"mlp"`` nor ``"mpf"``.
    """
    # Re-applies the thread count torch currently reports.
    torch.set_num_threads(torch.get_num_threads())

    exp_name = (
        args.exp_name
        + "_a_" + convert_to_underscore(args.alpha)
        + "_g_" + convert_to_underscore(args.gamma)
        + "_nc_" + str(args.n_components)
        + "_eplen_" + str(args.eplen)
    )

    logger_kwargs = setup_logger_kwargs(exp_name, args.seed)

    env_fn = build_env(args, render_mode="rgb_array")

    if args.ac_model == "mlp":
        logger.info(f"{YELLOW}Using MLP model{ENDC}")
        agent = MOP
        agent_kwargs = {}
        action_critic_model = MLPActorCritic
        ac_kwargs = {}  # hidden_sizes=args.layers,

    elif args.ac_model == "mpf":
        logger.info(f"{YELLOW}Using Mixture Policy{ENDC}")
        agent = MixtureEntropyMOP
        agent_kwargs = {}
        action_critic_model = FastMixturePolicyActorCritic
        ac_kwargs = dict(
            hidden_sizes=args.layers,
            dropout_rate=args.dropout_rate,
            n_components=args.n_components,
        )
    else:
        # Raising a bare string is itself a TypeError in Python 3 and would
        # hide this message; raise a real exception instead.
        raise ValueError(f"Model not known: {args.ac_model}")

    logger.info(f"{YELLOW}Using MOP model{ENDC}")
    logger.info(f"{YELLOW}Using {args.ac_model} model{ENDC}")
    logger.info(f"{YELLOW}actor ciritic arguments {ac_kwargs}{ENDC}")

    model = agent(
        exp_name=exp_name,
        env_fn=env_fn,
        actor_critic=action_critic_model,
        ac_kwargs=ac_kwargs,
        alpha=args.alpha,
        gamma=args.gamma,
        seed=args.seed,
        epochs=args.epochs,
        batch_size=args.bs,
        steps_per_epoch=args.steps_per_epoch,
        start_steps=args.start_steps,
        lr=args.lr,
        logger_kwargs=logger_kwargs,
        max_ep_len=args.eplen,
        update_every=16,
        test_mode=args.test,
        early_stopping=args.early_stopping,
        **agent_kwargs
    )

    return model
    