# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp
import dataclasses
import logging

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore
import omegaconf

from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from .encoder_rl import DiscreteRLAgent, DiscreteRLAgentConfig

logger = logging.getLogger(__name__)
MetaDict = tp.Mapping[str, np.ndarray]


@dataclasses.dataclass
class DiaynConfig:
    """Hydra structured config for ``DiaynAgent``.

    Fields defaulting to ``omegaconf.II("x")`` are interpolations (``${x}``)
    resolved against the root config. Fields defaulting to ``omegaconf.MISSING``
    hold the ``"???"`` placeholder and are expected to be filled in before use.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.diayn.DiaynAgent"
    name: str = "diayn_agent"
    reward_free: bool = omegaconf.II("reward_free")
    custom_reward: tp.Optional[str] = omegaconf.II("custom_reward")
    obs_type: str = omegaconf.MISSING  # to be specified later
    obs_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later; obs_shape[0] is used as the observation size
    action_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later; must be 1-D (number of discrete actions)
    device: str = omegaconf.II("device")  # ${device}
    lr: float = 1e-4  # learning rate of the skill policies and of the dynamics model
    use_tb: bool = omegaconf.II("use_tb")  # ${use_tb}
    use_wandb: bool = omegaconf.II("use_wandb")  # ${use_wandb}
    use_hiplog: bool = omegaconf.II("use_hiplog")  # ${use_hiplog}
    num_expl_steps: int = omegaconf.MISSING  # copied into each per-skill DiscreteRLAgentConfig
    hidden_dim: int = 1024  # width of the discriminator and of the skill policies
    feature_dim: int = 512
    stddev_schedule: str = "0.2"  # "linear(1,0.2,200000)"
    stddev_clip: float = 0.3  # 1.0
    nstep: int = 1
    batch_size: int = 128  # 256 for pixels; replay-buffer sample size
    goal_space: tp.Optional[str] = omegaconf.II("goal_space")
    fb_reward: bool = False  # if True, the agent creates an identity feature_net
    future_ratio: float = 0
    preprocess: bool = False
    add_trunk: bool = False
    supervised: bool = True
    update_every_steps: float = 2  # copied into each per-skill DiscreteRLAgentConfig
    critic_target_tau: float = 0.01  # copied into each per-skill config as target_tau
    discount: float = 0.99  # discount factor passed to the skill policies' Q updates

    # DIAYN specific parameters
    num_skills: int = 10  # Number of skills to learn (one policy per skill)
    use_state: bool = True  # Whether to use state-transition pairs or just next state
    spectral_norm: bool = False  # Whether to use spectral normalization
    skill_lr: float = 1e-4  # Learning rate for skill discriminator
    intrinsic_reward_scale: float = 1.0  # Scale factor for intrinsic rewards

# Register DiaynConfig so Hydra can select it as `agent=diayn_agent`.
cs = ConfigStore.instance()
cs.store(node=DiaynConfig, group="agent", name="diayn_agent")


class OneHotDynamicsModel(nn.Module):
    """
    Predicts single-step dynamics with one-hot observations and actions.

    Input: one-hot observation + discrete action (integer index or one-hot).
    Output: logits over the ``obs_dim`` possible next observations.

    Args:
        obs_dim: Length of the one-hot observation vector.
        action_dim: Number of discrete actions.
        hidden_dim: Width of the two hidden layers of the MLP.
    """

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # The network consumes the concatenation [one-hot obs, one-hot action].
        input_dim = obs_dim + action_dim

        # MLP producing one logit per possible next observation.
        self.dynamics_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, obs_dim),
        )

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Xavier-uniform weights and zero biases for every ``nn.Linear``."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            nn.init.constant_(module.bias, 0.0)

    def action_to_onehot(self, action: torch.Tensor) -> torch.Tensor:
        """Convert integer actions to one-hot vectors.

        Args:
            action: Integer-valued actions of shape ``(batch_size,)`` or
                ``(batch_size, 1)``; the tensor is flattened before encoding.

        Returns:
            Tensor of shape ``(batch_size, action_dim)`` in the default float
            dtype, on the same device as ``action``.
        """
        indices = action.reshape(-1).long()
        action_onehot = torch.zeros(indices.shape[0], self.action_dim, device=action.device)
        action_onehot.scatter_(1, indices.unsqueeze(1), 1)
        return action_onehot

    def forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to predict next-observation logits.

        Args:
            obs: One-hot observation tensor of shape (batch_size, obs_dim)
            action: Discrete action tensor of shape (batch_size,) or (batch_size, 1),
                or an already one-hot tensor of shape (batch_size, action_dim)

        Returns:
            next_obs_logits: Logits for next observation of shape (batch_size, obs_dim)
        """
        # (B,) and (B, 1) are integer indices; anything else is taken as one-hot.
        # Previously a (B,) tensor crashed here on `.squeeze(1)`.
        if action.dim() == 1 or (action.dim() == 2 and action.shape[1] == 1):
            action_onehot = self.action_to_onehot(action)
        else:
            action_onehot = action

        obs_action = torch.cat([obs, action_onehot], dim=1)
        return self.dynamics_net(obs_action)

    def predict(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """
        Predict the next-observation distribution.

        Returns:
            next_obs_probs: Softmax over the next-observation logits, shape (batch_size, obs_dim)
        """
        return F.softmax(self.forward(obs, action), dim=1)

    def predict_deterministic(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """
        Predict the next observation deterministically (argmax of the logits).

        Returns:
            next_obs_onehot: One-hot next observation with the same shape and dtype as ``obs``
        """
        next_obs_idx = torch.argmax(self.forward(obs, action), dim=1)
        next_obs_onehot = torch.zeros_like(obs)
        next_obs_onehot.scatter_(1, next_obs_idx.unsqueeze(1), 1)
        return next_obs_onehot

    def compute_loss(self, obs: torch.Tensor, action: torch.Tensor, next_obs: torch.Tensor) -> torch.Tensor:
        """
        Compute the cross-entropy loss for dynamics prediction.

        Args:
            obs: Current one-hot observation
            action: Discrete action (index or one-hot, see ``forward``)
            next_obs: Target one-hot next observation; its argmax is the class label

        Returns:
            loss: Scalar cross-entropy loss (mean over the batch)
        """
        next_obs_logits = self.forward(obs, action)
        next_obs_targets = torch.argmax(next_obs, dim=1)
        return F.cross_entropy(next_obs_logits, next_obs_targets)


class DiaynDiscriminator(nn.Module):
    """Discriminator network for DIAYN that predicts skills from state transitions.

    Args:
        obs_dim: Dimensionality of a single observation.
        num_skills: Number of discrete skills (size of the output logits).
        hidden_dim: Width of the two hidden layers.
        use_state: If True the input is ``concat(obs, next_obs)``, otherwise only ``next_obs``.
        spectral_norm: Wrap every ``nn.Linear`` in ``nn.utils.spectral_norm``.
        device: Stored as ``self.device``; the module is not moved here.
        lr: Learning rate of the internal Adam optimizer (default matches the
            previously hard-coded value).
    """

    def __init__(self, obs_dim: int, num_skills: int, hidden_dim: int = 1024,
                 use_state: bool = True, spectral_norm: bool = False, device: str = "cpu",
                 lr: float = 1e-4):
        super().__init__()

        self.skill_dim = num_skills
        self.use_state = use_state
        self.device = device

        # Input size: obs_dim + obs_dim if using state transitions, else just obs_dim
        input_size = obs_dim * 2 if use_state else obs_dim

        # Optionally apply spectral normalization to every linear layer.
        wrap = nn.utils.spectral_norm if spectral_norm else (lambda layer: layer)
        self.skill_pred_net = nn.Sequential(
            wrap(nn.Linear(input_size, hidden_dim)),
            nn.ReLU(),
            wrap(nn.Linear(hidden_dim, hidden_dim)),
            nn.ReLU(),
            wrap(nn.Linear(hidden_dim, num_skills)),
        )

        # Loss criterion and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.skill_pred_net.parameters(), lr=lr)

    def forward(self, obs: torch.Tensor, next_obs: torch.Tensor) -> torch.Tensor:
        """Return unnormalized skill logits of shape (batch_size, num_skills)."""
        input_tensor = torch.cat([obs, next_obs], dim=-1) if self.use_state else next_obs
        return self.skill_pred_net(input_tensor)

    @torch.no_grad()
    def get_intrinsic_reward(self, obs: torch.Tensor, next_obs: torch.Tensor,
                             skill: torch.Tensor) -> torch.Tensor:
        """Compute the DIAYN pseudo-reward ``log q(z | s, s') - log p(z)``.

        ``z`` is the skill given in ``skill`` (one-hot, argmax taken), not the
        discriminator's prediction. With a uniform prior ``p(z) = 1 / num_skills``,
        ``-log p(z) = log(num_skills)``. The reward is computed without gradient
        tracking: it is a training target for the policies, not a loss for the
        discriminator.

        Returns:
            Tensor of shape (batch_size,).
        """
        log_probs = F.log_softmax(self.forward(obs, next_obs), dim=1)
        skill_idx = torch.argmax(skill, dim=1).to(log_probs.device)
        skill_log_prob = log_probs.gather(1, skill_idx.unsqueeze(1)).squeeze(1)
        return skill_log_prob + np.log(self.skill_dim)

    def update_discriminator(self, obs: torch.Tensor, next_obs: torch.Tensor,
                             skills: torch.Tensor) -> tp.Dict[str, float]:
        """Take one optimizer step on the skill cross-entropy loss.

        Returns:
            ``discriminator_loss``, ``discriminator_accuracy`` (fraction of
            correctly predicted skills) and ``discriminator_grad`` (sum of the
            per-parameter gradient L2 norms, measured before the step).
        """
        skill_logits = self.forward(obs, next_obs)
        skill_idx = torch.argmax(skills, dim=1)

        loss = self.criterion(skill_logits, skill_idx)
        with torch.no_grad():
            accuracy = (torch.argmax(skill_logits, dim=1) == skill_idx).float().mean()

        self.optimizer.zero_grad()
        loss.backward()
        # One host sync for the diagnostic instead of one `.item()` per parameter.
        grad_norms = [p.grad.detach().norm(2) for p in self.skill_pred_net.parameters()
                      if p.grad is not None]
        grad = torch.stack(grad_norms).sum().item() if grad_norms else 0.0
        self.optimizer.step()

        return {
            "discriminator_loss": loss.item(),
            "discriminator_accuracy": accuracy.item(),
            "discriminator_grad": grad
        }


class DiaynAgent:
    """
    DIAYN (Diversity is All You Need) agent with one discrete Q-learning policy per skill.

    A discriminator learns to predict the skill from ``(obs, next_obs)``; its
    log-probability of the active skill is the intrinsic reward of that skill's
    policy, which encourages distinguishable behaviours.

    Training (see ``update``) runs in two phases:

    1. While ``step < DYNAMICS_WARMUP_STEPS`` only the one-hot dynamics model
       is fit on replay-buffer transitions.
    2. Afterwards each replay observation is paired with a random skill. The
       skill's greedy action and the dynamics model's predicted next
       observation replace the logged ones, and the discriminator and skill
       policies are trained on these imagined transitions.
    """

    # ``update`` trains only the dynamics model while ``step`` is below this value.
    DYNAMICS_WARMUP_STEPS = 10000

    def __init__(self, **kwargs: tp.Any):
        cfg = DiaynConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        self.action_dim = cfg.action_shape[0]
        self.obs_dim = cfg.obs_shape[0]

        # Skill discriminator q(z | s, s').
        self.discriminator = DiaynDiscriminator(
            obs_dim=self.obs_dim,
            num_skills=cfg.num_skills,
            hidden_dim=cfg.hidden_dim,
            use_state=cfg.use_state,
            spectral_norm=cfg.spectral_norm,
            device=cfg.device
        ).to(cfg.device)
        # The discriminator builds its own Adam optimizer; apply the configured
        # `skill_lr` to it (previously this config field had no effect).
        for param_group in self.discriminator.optimizer.param_groups:
            param_group["lr"] = cfg.skill_lr

        # Learned model used to imagine the next observation of a skill's action.
        self.dynamics_model = OneHotDynamicsModel(
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
            hidden_dim=256
        ).to(cfg.device)
        self.dynamics_optimizer = torch.optim.Adam(self.dynamics_model.parameters(), lr=cfg.lr)

        # One independent RL agent per skill, keyed by skill index.
        self.skill_policies: tp.Dict[int, DiscreteRLAgent] = {}
        self.skill_policy_configs: tp.Dict[int, DiscreteRLAgentConfig] = {}
        for skill_idx in range(cfg.num_skills):
            skill_config = DiscreteRLAgentConfig()
            skill_config.action_shape = cfg.action_shape
            skill_config.hidden_dim = cfg.hidden_dim
            skill_config.lr = cfg.lr
            skill_config.device = cfg.device
            skill_config.batch_size = cfg.batch_size
            skill_config.update_every_steps = cfg.update_every_steps
            skill_config.target_tau = cfg.critic_target_tau
            skill_config.num_expl_steps = cfg.num_expl_steps
            skill_config.expl_eps = 0.1  # Exploration epsilon for this skill
            skill_config.q_type = "full"

            skill_policy = DiscreteRLAgent(self.obs_dim, skill_config)
            skill_policy.load_encoder(nn.Identity())  # policies consume raw observations

            self.skill_policies[skill_idx] = skill_policy
            self.skill_policy_configs[skill_idx] = skill_config

        # Row k is the one-hot vector of skill k.
        self.skill_onehot = np.eye(cfg.num_skills, dtype=np.float32)
        self.current_skill: tp.Optional[np.ndarray] = None
        self.current_skill_idx: tp.Optional[int] = None
        self.best_skill: tp.Optional[np.ndarray] = None

        # Feature network for fb_reward (optional)
        self.feature_net: tp.Optional[nn.Module] = None
        if self.cfg.fb_reward:
            self.feature_net = nn.Identity()
            self.feature_net.eval()

        self.train()

    def train(self, training: bool = True) -> None:
        """Set training mode for the discriminator, dynamics model and all skill policies."""
        self.training = training
        self.discriminator.train(training)
        self.dynamics_model.train(training)
        for skill_policy in self.skill_policies.values():
            skill_policy.train(training)

    def sample_skill(self, batch_size: int = 1) -> np.ndarray:
        """Sample ``batch_size`` uniformly random one-hot skills, shape (batch_size, num_skills)."""
        skill_indices = np.random.randint(0, self.cfg.num_skills, size=batch_size)
        return self.skill_onehot[skill_indices]

    def get_best_skill(self) -> tp.Optional[np.ndarray]:
        """Get the one-hot skill selected by the last inference call (or None)."""
        return self.best_skill

    def act(self, obs: np.ndarray, meta: MetaDict, step: int, eval_mode: bool = False) -> np.ndarray:
        """Select an action with the policy of the currently active skill.

        When no skill is active, the best skill found by ``inference`` /
        ``find_best_skill_for_goal`` is used in eval mode if available;
        otherwise a random skill is sampled. The chosen skill stays active
        until ``current_skill`` is reset to None.
        """
        # NOTE(review): nothing in this file resets `current_skill` between
        # episodes apart from the best-skill search methods; confirm callers do.
        if self.current_skill is None:
            if eval_mode and self.best_skill is not None:
                self.current_skill = self.best_skill
            else:
                self.current_skill = self.sample_skill(1)[0]
            self.current_skill_idx = int(np.argmax(self.current_skill))

        skill_policy = self.skill_policies[self.current_skill_idx]
        action = skill_policy.act(obs, meta, step, eval_mode)

        # Wrap scalar actions (including numpy scalars) into a 1-element array.
        if isinstance(action, (int, float, np.integer, np.floating)):
            action = np.array([action])
        return action

    def update_agent(self, obs: torch.Tensor, action: torch.Tensor,
                     next_obs: torch.Tensor, skills: torch.Tensor) -> tp.Dict[str, float]:
        """Update the discriminator, then the skill policies with DIAYN rewards."""
        metrics: tp.Dict[str, float] = {}

        metrics.update(self.discriminator.update_discriminator(obs, next_obs, skills))

        # Rewards are regression targets for the Q-learners: detach them so no
        # gradient (or shared autograd graph) reaches the discriminator.
        intrinsic_rewards = self.discriminator.get_intrinsic_reward(obs, next_obs, skills).detach()
        metrics["intrinsic_reward_mean"] = intrinsic_rewards.mean().item()
        scaled_rewards = intrinsic_rewards * self.cfg.intrinsic_reward_scale

        metrics.update(self.update_skill_policies(obs, action, next_obs, skills, scaled_rewards))
        return metrics

    def update_skill_policies(self, obs: torch.Tensor, action: torch.Tensor,
                              next_obs: torch.Tensor, skills: torch.Tensor,
                              intrinsic_rewards: torch.Tensor) -> tp.Dict[str, float]:
        """Run one Q update per skill on the transitions labelled with that skill.

        Skills with fewer than two transitions in the batch are skipped.
        Metrics are prefixed with ``skill_{idx}_``.
        """
        metrics: tp.Dict[str, float] = {}
        skill_ids = torch.argmax(skills, dim=1)

        for skill_idx in torch.unique(skill_ids).tolist():
            mask = skill_ids == skill_idx
            if int(mask.sum()) < 2:  # Need at least 2 samples for meaningful update
                continue

            skill_reward = intrinsic_rewards[mask].detach().unsqueeze(1)  # (n, 1)
            skill_metrics = self.skill_policies[skill_idx].update_q(
                obs=obs[mask],
                action=action[mask],
                reward=skill_reward,
                discount=torch.ones_like(skill_reward) * self.cfg.discount,
                next_obs=next_obs[mask],
                step=0  # Step counter not used in this context
            )
            for key, value in skill_metrics.items():
                metrics[f"skill_{skill_idx}_{key}"] = value

        return metrics

    def update_dynamics_model(self, obs: torch.Tensor, action: torch.Tensor, next_obs: torch.Tensor) -> tp.Dict[str, float]:
        """Take one optimizer step on the dynamics model's cross-entropy loss."""
        loss = self.dynamics_model.compute_loss(obs, action, next_obs)

        self.dynamics_optimizer.zero_grad()
        loss.backward()
        self.dynamics_optimizer.step()

        return {"dynamics_loss": loss.item()}

    def _imagine_skill_transitions(self, obs: torch.Tensor, action: torch.Tensor,
                                   skills: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Return (greedy skill actions, predicted next observations) for ``obs``.

        The returned action has one entry per sample, shaped like ``action``
        (e.g. (B, 1)) and in its dtype; the inputs are not modified.
        """
        skill_ids = torch.argmax(skills, dim=1)
        greedy = torch.zeros(obs.size(0), dtype=torch.long, device=obs.device)
        with torch.no_grad():
            for skill_idx in torch.unique(skill_ids).tolist():
                mask = skill_ids == skill_idx
                q_values = self.skill_policies[skill_idx].q_function(obs[mask])
                greedy[mask] = torch.argmax(q_values, dim=1)
            next_obs = self.dynamics_model.predict_deterministic(obs, greedy.unsqueeze(1))
        # NOTE(review): assumes the replay action holds one index per sample.
        new_action = greedy.to(action.dtype).reshape(-1, *([1] * (action.dim() - 1)))
        return new_action, next_obs

    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        """Run one training update on a batch from ``replay_loader``.

        Returns ``{"dynamics_loss": ...}`` during the warm-up phase and ``{}``
        afterwards (see the note at the end).
        """
        metrics: tp.Dict[str, float] = {}

        batch = replay_loader.sample(self.cfg.batch_size)
        batch = batch.to(self.cfg.device)
        obs = batch.obs
        action = batch.action
        next_obs = batch.next_obs

        if step < self.DYNAMICS_WARMUP_STEPS:
            # Phase 1: fit the dynamics model on real transitions only.
            metrics.update(self.update_dynamics_model(obs, action, next_obs))
            if step % 500 == 0:
                logger.info("Dynamics training step %d, loss: %s", step, metrics["dynamics_loss"])
            return metrics

        # Phase 2: relabel the batch with random skills and imagined transitions.
        skills = torch.as_tensor(self.sample_skill(obs.size(0)), device=self.cfg.device)
        skill_action, skill_next_obs = self._imagine_skill_transitions(obs, action, skills)
        metrics.update(self.update_agent(obs, skill_action, skill_next_obs, skills))

        if step % 50 == 0:
            logger.info(
                "DIAYN training step %d, disc metrics: %s, acc: %s, grad: %s",
                step,
                metrics["discriminator_loss"],
                metrics["discriminator_accuracy"],
                metrics["discriminator_grad"],
            )

        # NOTE(review): phase-2 metrics are computed but (as before) not returned.
        # Per-skill keys vary between batches; confirm the caller's logger
        # tolerates changing keys before returning `metrics` here.
        return {}

    def check_policies(self, env, step, work_dir):
        """Plot the greedy value function and greedy actions of every skill policy."""
        # The state enumeration is the same for every skill, so build it once.
        state_list = env.get_state_list()
        obs_batch = torch.cat(
            [torch.tensor(env.get_obs_from_state(state)).unsqueeze(0) for state in state_list],
            dim=0,
        ).to(self.cfg.device)
        obs_cpu = obs_batch.cpu()

        for skill_idx, skill_policy in self.skill_policies.items():
            with torch.no_grad():
                q_values = skill_policy.q_function(obs_batch)
            v_values = torch.max(q_values, dim=1)[0]
            greedy_actions = torch.argmax(q_values, dim=1).cpu()
            env.plot_v_function(work_dir, obs_cpu, v_values, greedy_actions, f"check_policy_{step}_{skill_idx}_v_function")

    def inference(self, replay_loader: ReplayBuffer, infer_logger, q_pos, q_neg,
                  reward_fn: tp.Optional[tp.Callable[[torch.Tensor], torch.Tensor]]) -> int:
        """Select the skill whose imagined one-step transitions score highest.

        For each skill, its greedy action on a replay batch is rolled through the
        dynamics model. Skills are scored by the mean of ``reward_fn(next_obs)``
        (called on CPU tensors) when given, otherwise by the mean intrinsic
        reward. ``infer_logger``, ``q_pos`` and ``q_neg`` are unused.

        Returns:
            The best skill index; it is also stored in ``self.best_skill``.
        """
        batch = replay_loader.sample(self.cfg.batch_size)
        batch = batch.to(self.cfg.device)
        obs = batch.obs

        skill_scores: tp.List[float] = []
        with torch.no_grad():
            for skill_idx in range(self.cfg.num_skills):
                skill = torch.zeros(obs.size(0), self.cfg.num_skills, device=self.cfg.device)
                skill[:, skill_idx] = 1.0

                q_values = self.skill_policies[skill_idx].q_function(obs)
                skill_action = torch.argmax(q_values, dim=1)
                next_obs = self.dynamics_model.predict_deterministic(obs, skill_action.unsqueeze(1))

                if reward_fn is not None:
                    # Only the external task reward ranks skills in this case.
                    score = reward_fn(next_obs.cpu())
                else:
                    score = self.discriminator.get_intrinsic_reward(obs, next_obs, skill)
                skill_scores.append(score.mean().item())

        best_skill_idx = int(np.argmax(skill_scores))
        logger.info("DIAYN inference: best skill %d (score %s)", best_skill_idx, skill_scores[best_skill_idx])

        self.best_skill = self.skill_onehot[best_skill_idx].copy()
        # Clear the active skill so the next eval-mode `act` switches to it.
        self.current_skill = None
        self.current_skill_idx = None
        return best_skill_idx

    def find_best_skill_for_goal(self, replay_loader: ReplayBuffer, goal: np.ndarray,
                                 reward_fn: tp.Callable[[torch.Tensor], torch.Tensor]) -> tp.Dict[str, tp.Any]:
        """Find the skill maximizing mean intrinsic reward + mean goal reward.

        Averages over 10 replay batches. ``goal`` is unused. The best skill is
        stored in ``self.best_skill``.
        """
        num_batches = 10
        skill_rewards: tp.List[tp.List[float]] = [[] for _ in range(self.cfg.num_skills)]

        with torch.no_grad():
            for _ in range(num_batches):
                batch = replay_loader.sample(self.cfg.batch_size)
                batch = batch.to(self.cfg.device)
                obs = batch.obs
                next_obs = batch.next_obs

                # The goal reward depends only on the transition, not on the skill.
                goal_reward = torch.as_tensor(reward_fn(next_obs)).float().mean().item()

                for skill_idx in range(self.cfg.num_skills):
                    skill = torch.zeros(obs.size(0), self.cfg.num_skills, device=self.cfg.device)
                    skill[:, skill_idx] = 1.0
                    # Previously this line was commented out, leaving
                    # `intrinsic_reward` undefined (NameError).
                    intrinsic_reward = self.discriminator.get_intrinsic_reward(obs, next_obs, skill)
                    # mean(a + b) == mean(a) + mean(b), and summing floats avoids
                    # device/shape mismatches between the two reward tensors.
                    skill_rewards[skill_idx].append(intrinsic_reward.mean().item() + goal_reward)

        avg_rewards = [float(np.mean(rewards)) for rewards in skill_rewards]
        best_skill_idx = int(np.argmax(avg_rewards))
        best_skill_reward = avg_rewards[best_skill_idx]

        self.best_skill = self.skill_onehot[best_skill_idx].copy()
        # Clear the active skill so the next eval-mode `act` switches to it.
        self.current_skill = None
        self.current_skill_idx = None

        return {
            "best_skill_idx": best_skill_idx,
            "best_skill_reward": best_skill_reward,
            "skill_rewards": avg_rewards,
            "skill_diversity": np.std(avg_rewards),
            "goal_achievement": best_skill_reward
        }

    def q_function_inference(self, obs: torch.Tensor, q_pos: torch.Tensor, q_neg: torch.Tensor,
                             z: int) -> torch.Tensor:
        """Return Q-values of the policy for skill index ``z`` (``q_pos``/``q_neg`` unused)."""
        return self.skill_policies[int(z)].q_function(obs)

    def get_skill_q_values(self, obs: torch.Tensor, skill_idx: int) -> torch.Tensor:
        """Get Q-values for a specific skill; zeros of shape (B, action_dim) for unknown skills."""
        if skill_idx in self.skill_policies:
            return self.skill_policies[skill_idx].q_function(obs)
        return torch.zeros(obs.size(0), self.action_dim, device=obs.device)

    def infer_w_goal(self, replay_loader: ReplayBuffer, goal: np.ndarray) -> tp.Dict[str, float]:
        """Infer goal representation (no-op for DIAYN; returns an empty dict)."""
        return {}

    def distill_actor_ddpg(self, replay_loader: ReplayBuffer, logger, goal: np.ndarray) -> tp.Dict[str, float]:
        """Distill actor using DDPG (no-op for DIAYN; returns an empty dict)."""
        return {}


