"""
DIAYN (Diversity is All You Need) implementation.
This file implements DIAYN based on the paper "Diversity is All You Need: Learning Skills without a Reward Function"
(https://arxiv.org/abs/1802.06070)

This implementation also supports SMERL (Structured Maximum Entropy Reinforcement Learning)
following "One Solution is Not All You Need: Few-Shot Extrapolation via Structured MaxEnt RL"
(https://arxiv.org/abs/2010.14484)

This implementation is inspired by
1. https://github.com/Egiob/DiversityIsAllYouNeed-SB3/blob/master/stable_baselines3/diayn/diayn.py
2. https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py
"""

import copy
import logging
import os
from pathlib import Path
from typing import List, Optional, Union

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from omegaconf import DictConfig, OmegaConf

from src.logging_utils import MetricLogger

# Clamp range applied to the actor's predicted log standard deviation.
LOG_STD_MIN = -5
LOG_STD_MAX = 2


class SACPolicy(nn.Module):
    """
    Soft actor-critic networks bundled into a single module.

    Contains a tanh-squashed Gaussian actor (state-dependent mean and log-std
    heads on a shared MLP trunk), two independent Q-networks over the
    concatenated (state, action) input, and a target copy of each Q-network
    that only changes through ``update_targets``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int],
        action_space: Optional[gym.spaces.Box] = None,
        device: str = "cpu",
    ):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device

        # Actor: shared trunk followed by separate mean / log-std heads.
        self.actor_net = self._build_mlp(state_dim, hidden_dims)
        self.mean_layer = nn.Linear(hidden_dims[-1], action_dim)
        self.log_std_layer = nn.Linear(hidden_dims[-1], action_dim)

        # Twin critics, each ending in a scalar Q-value head.
        critic_in = state_dim + action_dim
        self.q1_net = self._build_mlp(critic_in, hidden_dims, out_dim=1)
        self.q2_net = self._build_mlp(critic_in, hidden_dims, out_dim=1)

        # Target critics start as exact copies of the online critics.
        self.q1_target = copy.deepcopy(self.q1_net)
        self.q2_target = copy.deepcopy(self.q2_net)

        # Affine map from the tanh range [-1, 1] onto the env's action bounds.
        if action_space is None:
            self.action_scale = 1.0
            self.action_bias = 0.0
        else:
            low, high = action_space.low, action_space.high
            self.action_scale = torch.FloatTensor((high - low) / 2.0).to(device)
            self.action_bias = torch.FloatTensor((high + low) / 2.0).to(device)

        self.to(device)

    @staticmethod
    def _build_mlp(in_dim, hidden_dims, out_dim=None):
        """Stack Linear+ReLU per hidden width, optionally followed by a linear head."""
        widths = [in_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        if out_dim is not None:
            modules.append(nn.Linear(hidden_dims[-1], out_dim))
        return nn.Sequential(*modules)

    def forward_actor(self, state):
        """Return the Gaussian mean and clamped log-std for ``state``."""
        features = self.actor_net(state)
        log_std = self.log_std_layer(features).clamp(LOG_STD_MIN, LOG_STD_MAX)
        return self.mean_layer(features), log_std

    def forward_critic(self, state, action):
        """Return (Q1, Q2) from the online critics."""
        joint = torch.cat([state, action], dim=1)
        return self.q1_net(joint), self.q2_net(joint)

    def forward_critic_target(self, state, action):
        """Return (Q1, Q2) from the target critics."""
        joint = torch.cat([state, action], dim=1)
        return self.q1_target(joint), self.q2_target(joint)

    def sample_action(self, state):
        """
        Draw a reparameterized action.

        Returns:
            (action, log_prob, mean_action): the sampled scaled action, its
            log-density with the tanh change-of-variables correction summed
            over action dimensions (keepdim), and the deterministic action
            obtained by squashing the mean.
        """
        mean, log_std = self.forward_actor(state)
        dist = torch.distributions.Normal(mean, log_std.exp())
        pre_tanh = dist.rsample()
        squashed = torch.tanh(pre_tanh)
        action = squashed * self.action_scale + self.action_bias

        # log pi(a) = log N(u) - log|da/du|, with a 1e-6 floor inside the log.
        jacobian = self.action_scale * (1 - squashed.pow(2)) + 1e-6
        log_prob = (dist.log_prob(pre_tanh) - torch.log(jacobian)).sum(1, keepdim=True)

        mean_action = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean_action

    def update_targets(self, tau):
        """Polyak-average online critic weights into the target critics."""
        pairs = ((self.q1_net, self.q1_target), (self.q2_net, self.q2_target))
        for online, target in pairs:
            for src, dst in zip(online.parameters(), target.parameters()):
                dst.data.copy_(tau * src.data + (1.0 - tau) * dst.data)


class DIAYNDiscriminator(nn.Module):
    """
    Skill classifier q(z | s) for DIAYN.

    An MLP maps a raw state to one logit per skill; ``forward`` returns the
    log-softmax of those logits over the skill dimension.
    """

    def __init__(
        self,
        state_dim: int,
        n_skills: int,
        hidden_dims: List[int],
        device: str = "cpu",
    ):
        super().__init__()
        self.state_dim = state_dim
        self.n_skills = n_skills
        self.device = device

        widths = [state_dim, *hidden_dims]
        body = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            body.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Final linear layer emits one logit per skill.
        head = nn.Linear(hidden_dims[-1], n_skills)
        self.net = nn.Sequential(*body, head)

        self.to(device)

    def forward(self, state):
        """Return log q(z | s) for every skill, normalized along dim 1."""
        return F.log_softmax(self.net(state), dim=1)


class DIAYNReplayBuffer:
    """
    Fixed-capacity circular buffer of DIAYN transitions.

    Stores (state, next_state, action, reward, done, skill) rows in
    preallocated float32 NumPy arrays; once full, the oldest rows are
    overwritten. ``sample`` returns float tensors on ``device``.
    """

    def __init__(
        self,
        buffer_size: int,
        state_dim: int,
        action_dim: int,
        n_skills: int,
        device: str = "cpu",
    ):
        self.buffer_size = buffer_size
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.n_skills = n_skills
        self.device = device

        def alloc(width):
            return np.zeros((buffer_size, width), dtype=np.float32)

        self.states = alloc(state_dim)
        self.next_states = alloc(state_dim)
        self.actions = alloc(action_dim)
        self.rewards = alloc(1)
        self.dones = alloc(1)
        self.skills = alloc(n_skills)

        # Next write slot and number of valid rows.
        self.ptr = 0
        self.size = 0

    def add_batch(
        self,
        states: np.ndarray,
        next_states: np.ndarray,
        actions: np.ndarray,
        rewards: float,
        dones: float,
        skills: np.ndarray,
    ):
        """Append a batch of transitions, wrapping around the end of the buffer."""
        count = states.shape[0]
        room = self.buffer_size - self.ptr

        if count > room:
            # Fill to the end, wrap to slot 0, then store the leftovers.
            self.states[self.ptr :] = states[:room]
            self.next_states[self.ptr :] = next_states[:room]
            self.actions[self.ptr :] = actions[:room]
            self.rewards[self.ptr :] = np.expand_dims(rewards[:room], -1)
            self.dones[self.ptr :] = np.expand_dims(dones[:room], -1)
            self.skills[self.ptr :] = skills[:room]
            self.ptr = 0
            self.size = self.buffer_size
            self.add_batch(
                states[room:],
                next_states[room:],
                actions[room:],
                rewards[room:],
                dones[room:],
                skills[room:],
            )
            return

        stop = self.ptr + count
        self.states[self.ptr : stop] = states
        self.next_states[self.ptr : stop] = next_states
        self.actions[self.ptr : stop] = actions
        self.rewards[self.ptr : stop] = np.expand_dims(rewards, -1)
        self.dones[self.ptr : stop] = np.expand_dims(dones, -1)
        self.skills[self.ptr : stop] = skills
        self.ptr = stop % self.buffer_size
        self.size = min(self.size + count, self.buffer_size)

    def add(
        self,
        state: np.ndarray,
        next_state: np.ndarray,
        action: np.ndarray,
        reward: float,
        done: float,
        skill: np.ndarray,
    ):
        """Write a single transition into the current slot and advance."""
        slot = self.ptr
        columns = (
            (self.states, state),
            (self.next_states, next_state),
            (self.actions, action),
            (self.rewards, reward),
            (self.dones, done),
            (self.skills, skill),
        )
        for store, value in columns:
            store[slot] = value

        self.ptr = (slot + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def sample(self, batch_size: int):
        """Draw ``batch_size`` rows uniformly (with replacement) as float tensors."""
        picks = np.random.randint(0, self.size, size=batch_size)
        stores = (
            self.states,
            self.next_states,
            self.actions,
            self.rewards,
            self.dones,
            self.skills,
        )
        return tuple(torch.FloatTensor(store[picks]).to(self.device) for store in stores)


class DIAYN:
    """
    DIAYN implementation (Diversity is All You Need) with optional SMERL functionality.

    DIAYN learns diverse skills without extrinsic rewards by optimizing an objective that:
    1. Maximizes the mutual information between skills and states
    2. Maximizes the entropy of the policy
    3. Minimizes the mutual information between skills and actions given the state

    SMERL extends this by optimizing the extrinsic reward and adding the diversity reward
    only for skills whose recent extrinsic return reaches (within a margin) a performance
    threshold, encouraging diversity only among high-performing policies.
    """

    def __init__(
        self,
        env_cfg: DictConfig,
        actor_lr: float,
        critic_lr: float,
        disc_lr: float,
        alpha_lr: float,
        n_skills: int,
        buffer_size: int,
        batch_size: int,
        hidden_dims: List[int],
        gamma: float,
        tau: float,
        alpha: Union[float, str],
        learning_starts: int,
        policy_freq: int,
        target_update_freq: int,
        # DIAYN/SMERL specific
        combined_rewards: bool,
        beta: float,
        smerl_threshold: Optional[float],
        smerl_eps: float,
        use_skill_prior: bool,
        # Training parameters
        exp_dir: Path,
        total_iterations: int,
        eval_freq: int,
        n_eval_episodes: int,
        save_freq: int,
        log_freq: int,
        device: str,
        seed: int,
        metric_logger: Optional[MetricLogger],
        info_logger: Optional[logging.Logger],
        **kwargs,
    ):
        """
        Store the configuration, seed the RNGs, and build environments and models.

        Args:
            env_cfg: Environment config; ``num_envs``, ``env_id``, ``state_dim``,
                ``action_dim`` and the optional ``env_kwargs`` are read from it.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Adam learning rate for the twin critics.
            disc_lr: Adam learning rate for the discriminator.
            alpha_lr: Adam learning rate for log-alpha (only used with ``alpha="auto"``).
            n_skills: Number of discrete skills.
            buffer_size: Replay buffer capacity in transitions.
            batch_size: Minibatch size per gradient step.
            hidden_dims: Hidden layer widths shared by every MLP.
            gamma: Discount factor.
            tau: Polyak coefficient for the target critics.
            alpha: Fixed entropy coefficient, or ``"auto"`` to learn it.
            learning_starts: Vector-env steps of uniformly random actions before
                the policy is used and updates begin.
            policy_freq: Actor (and alpha) update period on the update counter.
            target_update_freq: Target-critic update period on the update counter.

            combined_rewards: Use env reward + ``beta`` * DIAYN reward instead of
                the DIAYN reward alone. Forced on when ``smerl_threshold`` is set.
            beta: Weight on the DIAYN reward in combined/SMERL modes.
            smerl_threshold: Return threshold enabling SMERL gating; None disables SMERL.
            smerl_eps: Relative margin below ``smerl_threshold`` that still counts as passing.
            use_skill_prior: Subtract log p(z) from the discriminator log-probability.

            exp_dir: Output directory for checkpoints/policies (None skips saving).
            total_iterations: Number of vector-env steps to train for.
            eval_freq: Evaluation period.
            n_eval_episodes: Episodes per skill during evaluation.
            save_freq: Checkpoint period in vector-env steps (<= 0 disables).
            log_freq: Training-metric logging period in vector-env steps.
            device: Torch device string (e.g. "cpu" or "cuda").
            seed: Random seed.
            metric_logger: Optional metric sink.
            info_logger: Optional text logger.
            **kwargs: Accepted for config compatibility and ignored.
        """
        # Environment and optimisation hyperparameters
        self.env_cfg = env_cfg
        self.num_envs = env_cfg.num_envs
        self.actor_lr, self.critic_lr = actor_lr, critic_lr
        self.disc_lr, self.alpha_lr = disc_lr, alpha_lr
        self.n_skills = n_skills
        self.buffer_size, self.batch_size = buffer_size, batch_size
        self.hidden_dims = hidden_dims
        self.gamma, self.tau = gamma, tau
        self.alpha_init = alpha  # resolved into self.alpha by _setup_model()
        self.learning_starts = learning_starts
        self.policy_freq, self.target_update_freq = policy_freq, target_update_freq

        # Reward shaping: SMERL always implies combined rewards
        smerl_enabled = smerl_threshold is not None
        self.combined_rewards = True if smerl_enabled else combined_rewards
        self.beta = beta
        self.smerl_threshold = smerl_threshold
        self.smerl_eps = smerl_eps
        self.use_skill_prior = use_skill_prior

        # Run-level settings and loggers
        self.exp_dir = exp_dir
        self.total_iterations = total_iterations
        self.eval_freq, self.n_eval_episodes = eval_freq, n_eval_episodes
        self.save_freq, self.log_freq = save_freq, log_freq
        self.device = device
        self.seed = seed
        self.logger = metric_logger
        self.info_logger = info_logger

        if smerl_enabled and self.info_logger:
            self.info_logger.info(
                "SMERL mode enabled with threshold: "
                f"{self.smerl_threshold} and margin: {self.smerl_eps}"
            )

        # Order matters: seed before any network or environment is created.
        self._set_seed(seed)
        self._setup_env()
        self._setup_model()

        # Running statistics filled in by train()
        self.num_timesteps = 0
        self.num_episodes = 0
        self.episode_rewards = []
        self.episode_env_rewards = []
        self.episode_lengths = []
        self.skill_episode_rewards = [[] for _ in range(self.n_skills)]

    def _set_seed(self, seed):
        """Seed NumPy and PyTorch (plus the current CUDA device when present)."""
        for seeder in (np.random.seed, torch.manual_seed):
            seeder(seed)
        if not torch.cuda.is_available():
            return
        torch.cuda.manual_seed(seed)

    def _setup_env(self):
        """
        Create the vectorized training environments and one evaluation env.

        Each training sub-env is wrapped in ``RecordEpisodeStatistics``, and its
        action space is seeded with ``self.seed + env_index``. The batched
        action space used for warm-up sampling is seeded with ``self.seed``.
        ``state_dim`` and ``action_dim`` come from the environment config
        rather than being inferred from the spaces.
        """
        env_kwargs = (
            OmegaConf.to_container(self.env_cfg.env_kwargs, resolve=True)
            if "env_kwargs" in self.env_cfg
            else {}
        )

        def make_env(env_id, seed, env_kwargs):
            def thunk():
                env = gym.make(env_id, **env_kwargs)
                env = gym.wrappers.RecordEpisodeStatistics(env)
                env.action_space.seed(seed)
                return env

            return thunk

        self.envs = gym.vector.AsyncVectorEnv(
            [
                # Offset by the run seed so that different runs get different streams.
                make_env(self.env_cfg.env_id, self.seed + i, env_kwargs)
                for i in range(self.env_cfg.num_envs)
            ]
        )
        # train() samples warm-up actions from the *batched* action space, a
        # Space object distinct from the per-env spaces seeded in thunk(); seed
        # it too so random exploration is reproducible for a given run seed.
        self.envs.action_space.seed(self.seed)

        self.state_dim = self.env_cfg.state_dim
        self.action_dim = self.env_cfg.action_dim

        # Create evaluation environment
        self.eval_env = gym.make(self.env_cfg.env_id, **env_kwargs)

    def _setup_model(self):
        """
        Build the models, optimizers, entropy coefficient, buffer and skill prior.

        Creates the skill-conditioned SAC policy, the state-only discriminator,
        Adam optimizers, the (fixed or learnable) entropy coefficient, the
        replay buffer, and the log of the uniform skill prior. When
        ``exp_dir`` is set, it also creates ``checkpoints/`` and ``policies/``
        under it.
        """
        # Policy (actor-critic) sees the state concatenated with the one-hot skill.
        self.policy = SACPolicy(
            state_dim=self.state_dim + self.n_skills,  # Augment state with skill
            action_dim=self.action_dim,
            hidden_dims=self.hidden_dims,
            action_space=self.envs.single_action_space,
            device=self.device,
        )

        # Discriminator q(z|s) sees the raw state only.
        self.discriminator = DIAYNDiscriminator(
            state_dim=self.state_dim,
            n_skills=self.n_skills,
            hidden_dims=self.hidden_dims,
            device=self.device,
        )

        # Optimizers. The target critics are not optimized directly; they move
        # only through SACPolicy.update_targets().
        self.actor_optimizer = optim.Adam(
            list(self.policy.actor_net.parameters())
            + list(self.policy.mean_layer.parameters())
            + list(self.policy.log_std_layer.parameters()),
            lr=self.actor_lr,
        )

        self.critic_optimizer = optim.Adam(
            list(self.policy.q1_net.parameters())
            + list(self.policy.q2_net.parameters()),
            lr=self.critic_lr,
        )

        self.disc_optimizer = optim.Adam(
            self.discriminator.parameters(), lr=self.disc_lr
        )

        # Automatic entropy tuning
        if self.alpha_init == "auto":
            # Standard SAC heuristic: target entropy = -(action dimensionality).
            self.target_entropy = -np.prod(self.envs.single_action_space.shape).astype(
                np.float32
            )
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha = torch.exp(self.log_alpha).item()
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=self.alpha_lr)
        else:
            self.alpha = float(self.alpha_init)
            self.target_entropy = None
            self.log_alpha = None
            self.alpha_optimizer = None

        # Replay buffer
        self.replay_buffer = DIAYNReplayBuffer(
            buffer_size=self.buffer_size,
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            n_skills=self.n_skills,
            device=self.device,
        )

        # Uniform skill prior p(z) = 1 / n_skills, stored as log-probabilities so
        # that log q(z|s) - log p(z) is the DIAYN reward. (Zeros would mean
        # p(z) = 1 for every skill, which is not a distribution.)
        self.log_skill_prior = torch.full(
            (self.n_skills,), -float(np.log(self.n_skills)), device=self.device
        )

        # Create directories
        if self.exp_dir is not None:
            os.makedirs(self.exp_dir / "checkpoints", exist_ok=True)
            os.makedirs(self.exp_dir / "policies", exist_ok=True)

    def compute_diayn_reward(
        self, states: torch.tensor, skill_indices: np.ndarray
    ) -> np.ndarray:
        """
        Compute the DIAYN intrinsic reward ``log q(z|s) - log p(z)`` for a batch.

        Args:
            states: Batch of raw (non skill-augmented) states, given as a NumPy
                array or tensor of shape ``(batch, state_dim)``.
            skill_indices: Integer skill index for each row of ``states``.

        Returns:
            1-D float32 NumPy array with one reward per state. The ``log p(z)``
            term is only subtracted when ``use_skill_prior`` is enabled.
        """
        with torch.no_grad():
            state_tensor = torch.as_tensor(
                states, dtype=torch.float32, device=self.device
            )
            # gather() requires int64 indices; np.random.randint yields int32 on
            # some platforms (e.g. Windows with NumPy < 2), so cast explicitly.
            index = torch.as_tensor(skill_indices, dtype=torch.long, device=self.device)

            log_q_z_given_s = self.discriminator(state_tensor)
            # Log-probability the discriminator assigns to each row's active skill.
            disc_rewards = log_q_z_given_s.gather(1, index.unsqueeze(1)).squeeze(1)

            # Subtract log of skill prior if enabled
            if self.use_skill_prior:
                disc_rewards = disc_rewards - self.log_skill_prior[index]

            return disc_rewards.cpu().numpy()

    def compute_reward(self, env_rewards, diayn_rewards, skill_indices):
        """
        Combine environment and DIAYN rewards according to the active mode.

        - Pure DIAYN (``combined_rewards`` False): the DIAYN reward alone.
        - Combined (no SMERL threshold): ``env_reward + beta * diayn_reward``.
        - SMERL: ``env_reward``, plus ``beta * diayn_reward`` for envs whose
          skill has no finished episodes yet, or whose mean env return over its
          last 10 episodes is at least ``threshold - |eps * threshold|``.
        """
        if not self.combined_rewards:
            return diayn_rewards

        if self.smerl_threshold is None:
            return env_rewards + self.beta * diayn_rewards

        # Relative margin below the SMERL threshold that still earns the bonus.
        cutoff = self.smerl_threshold - abs(self.smerl_eps * self.smerl_threshold)
        shaped = env_rewards.copy()
        for env_idx, skill in enumerate(skill_indices):
            history = self.skill_episode_rewards[skill]
            if not history or np.mean(history[-10:]) >= cutoff:
                shaped[env_idx] += self.beta * diayn_rewards[env_idx]
        return shaped

    def select_action(self, states, skills, evaluate=False):
        """
        Query the skill-conditioned policy.

        Args:
            states: Environment state(s), batched or a single 1-D state.
            skills: Matching one-hot skill vector(s).
            evaluate: Return the deterministic (squashed-mean) action instead of a sample.

        Returns:
            NumPy actions, batched like the input (a flat vector for a single state).
        """
        stacked = np.concatenate([states, skills], axis=-1)
        with torch.no_grad():
            policy_input = torch.FloatTensor(stacked).to(self.device)
            single = policy_input.ndim == 1
            if single:
                policy_input = policy_input.unsqueeze(0)

            sampled, _, deterministic = self.policy.sample_action(policy_input)
            chosen = deterministic if evaluate else sampled
            out = chosen.cpu().numpy()
        return out.flatten() if single else out

    def sample_skill(self):
        """
        Draw one skill per environment uniformly at random.

        Returns:
            skills: ``(num_envs, n_skills)`` float one-hot matrix.
            skill_indices: ``(num_envs,)`` integer skill ids.
        """
        skill_indices = np.random.randint(0, self.n_skills, self.num_envs)
        one_hot = np.zeros((self.num_envs, self.n_skills))
        one_hot[np.arange(self.num_envs), skill_indices] = 1.0
        return one_hot, skill_indices

    def update(self, counter):
        """
        Run one optimisation step on the DIAYN/SAC components.

        The discriminator and both critics are updated on every call. The actor
        (and log-alpha when ``alpha == "auto"``) is updated when the global
        update index is a multiple of ``policy_freq``. The target critics are
        Polyak-averaged when it is a multiple of ``target_update_freq``.

        Args:
            counter: Position of this call within the current environment step
                (``train`` calls ``update`` ``num_envs`` times per step). The
                global update index is ``num_timesteps * num_envs + counter``.

        Returns:
            ``(train_metrics, info_metrics)`` dicts of Python floats, or
            ``({}, {})`` while the replay buffer holds fewer than
            ``batch_size`` transitions.
        """
        if self.replay_buffer.size < self.batch_size:
            return {}, {}

        states, next_states, actions, rewards, dones, skills = (
            self.replay_buffer.sample(self.batch_size)
        )

        # Skills are stored one-hot; the discriminator loss needs class indices.
        skill_indices = torch.argmax(skills, dim=1)
        # Global update index used to schedule the delayed updates below.
        update_step = self.num_timesteps * self.num_envs + counter

        # -------------------------
        # 1. Update Discriminator
        # -------------------------
        self.disc_optimizer.zero_grad()
        log_q_z_given_s = self.discriminator(states)
        # Cross-entropy between predicted skill distribution and the true skill.
        disc_loss = F.nll_loss(log_q_z_given_s, skill_indices)
        disc_loss.backward()
        self.disc_optimizer.step()

        # -------------------------
        # 2. Update Critic
        # -------------------------
        # SAC networks consume the state concatenated with the one-hot skill.
        state_skills = torch.cat([states, skills], dim=1)
        next_state_skills = torch.cat([next_states, skills], dim=1)

        with torch.no_grad():
            next_actions, next_log_probs, _ = self.policy.sample_action(
                next_state_skills
            )
            q1_next, q2_next = self.policy.forward_critic_target(
                next_state_skills, next_actions
            )
            # Clipped double-Q with the entropy bonus of the next action.
            q_next = torch.min(q1_next, q2_next) - self.alpha * next_log_probs
            q_target = rewards + (1.0 - dones) * self.gamma * q_next

        q1, q2 = self.policy.forward_critic(state_skills, actions)
        critic_loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # -------------------------
        # 3. Update Actor (and 4. Alpha), every policy_freq updates
        # -------------------------
        actor_loss = torch.tensor(0.0)
        # None means "alpha is fixed"; it is a tensor whenever alpha is learned.
        # Initialised up front so that fixed-alpha runs do not hit an unbound
        # name on actor-update steps.
        alpha_loss = torch.tensor(0.0) if self.alpha_optimizer is not None else None

        if update_step % self.policy_freq == 0:
            pi_actions, log_probs, _ = self.policy.sample_action(state_skills)
            q1_pi, q2_pi = self.policy.forward_critic(state_skills, pi_actions)
            min_q_pi = torch.min(q1_pi, q2_pi)

            # Maximise Q + entropy, i.e. minimise alpha * log pi - Q.
            actor_loss = ((self.alpha * log_probs) - min_q_pi).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            if self.alpha_optimizer is not None:
                alpha_loss = (
                    -self.log_alpha.exp() * (log_probs.detach() + self.target_entropy)
                ).mean()

                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()

                self.alpha = self.log_alpha.exp().item()

        # -------------------------
        # 5. Update Target Networks
        # -------------------------
        if update_step % self.target_update_freq == 0:
            self.policy.update_targets(self.tau)

        train_metrics = {
            "loss/discriminator": disc_loss.item(),
            "loss/critic": critic_loss.item(),
            "loss/actor": actor_loss.item(),
            "loss/alpha": alpha_loss.item() if alpha_loss is not None else 0.0,
            "train/alpha": self.alpha,
            # Critic estimates on the replay-batch actions, taken from the
            # critic step so the metric means the same thing on every call.
            "train/mean_q": ((q1.mean() + q2.mean()) / 2).item(),
        }

        with torch.no_grad():
            disc_preds = torch.argmax(log_q_z_given_s, dim=1)
            disc_accuracy = (disc_preds == skill_indices).float().mean().item()

        info_metrics = {"discriminator/accuracy": disc_accuracy}

        return train_metrics, info_metrics

    def train(self):
        """
        Main training loop.

        Runs ``total_iterations`` vector-environment steps. Each step:
        - selects actions (uniformly random before ``learning_starts``,
          otherwise from the skill-conditioned policy);
        - computes the DIAYN / combined reward and stores the real transitions;
        - once learning has started, performs ``num_envs`` gradient updates;
        - handles finished episodes (statistics, logging, skill resampling).

        Checkpoints are saved every ``save_freq`` steps. A ``"final"`` save is
        made at the end when ``exp_dir`` is set.
        """
        states, _ = self.envs.reset(seed=self.seed)
        # autoreset[i] is True when env i finished its episode on the previous
        # step. With next-step autoreset (AsyncVectorEnv's default in
        # gymnasium>=1.0 -- NOTE(review): confirm against the pinned version),
        # the following step() only resets that env: the action is ignored and
        # the first observation of the new episode is returned. That step is
        # not a real transition, so it is neither stored nor counted.
        autoreset = np.zeros(self.num_envs, dtype=bool)

        # Sample initial skill
        skills, skill_indices = self.sample_skill()

        # Per-env statistics for the episode currently in progress
        episode_timesteps = np.zeros(self.num_envs)
        episode_reward = np.zeros(self.num_envs)
        episode_env_reward = np.zeros(self.num_envs)
        episode_diayn_reward = np.zeros(self.num_envs)

        if self.info_logger:
            self.info_logger.info(
                f"Starting training with {self.n_skills} skills for {self.total_iterations} timesteps"
            )
            if self.combined_rewards:
                mode = "SMERL" if self.smerl_threshold is not None else "Combined"
                self.info_logger.info(f"Mode: {mode} (beta={self.beta})")
            else:
                self.info_logger.info("Mode: Pure DIAYN")

        for t in range(1, self.total_iterations + 1):
            self.num_timesteps += 1

            # Select action
            if self.num_timesteps < self.learning_starts:
                # Random exploration initially
                actions = self.envs.action_space.sample()
            else:
                actions = self.select_action(states, skills)

            next_states, env_rewards, terms, truncs, _info = self.envs.step(actions)
            # Only true terminations stop bootstrapping; truncations do not.
            dones = terms

            diayn_rewards = self.compute_diayn_reward(next_states, skill_indices)
            rewards = self.compute_reward(env_rewards, diayn_rewards, skill_indices)

            # Store only real transitions; skip envs that were just autoreset.
            valid = ~autoreset
            if valid.any():
                self.replay_buffer.add_batch(
                    states=states[valid],
                    next_states=next_states[valid],
                    actions=actions[valid],
                    rewards=rewards[valid],
                    dones=dones[valid].astype(float),
                    skills=skills[valid],
                )

            # Update state and tracking variables (autoreset steps excluded)
            states = next_states
            episode_timesteps[valid] += 1
            episode_reward[valid] += rewards[valid]
            episode_env_reward[valid] += env_rewards[valid]
            episode_diayn_reward[valid] += diayn_rewards[valid]

            # Update networks
            if self.num_timesteps >= self.learning_starts:
                # One update per env so the update-to-data ratio stays 1.
                for counter in range(self.num_envs):
                    train_metrics, info_metrics = self.update(counter)

                if self.logger is not None and t % self.log_freq == 0:
                    self.logger.update(train_metrics)
                    self.logger.update(info_metrics)
                    self.logger.log(self.num_timesteps)

            # End of episode handling
            autoreset = np.logical_or(terms, truncs)
            for i in np.flatnonzero(autoreset):
                self.num_episodes += 1
                self.episode_rewards.append(episode_reward[i])
                self.episode_env_rewards.append(episode_env_reward[i])
                self.episode_lengths.append(episode_timesteps[i])
                self.skill_episode_rewards[skill_indices[i]].append(
                    episode_env_reward[i]
                )

                if self.info_logger:
                    self.info_logger.info(
                        f"Episode {self.num_episodes} (Skill {skill_indices[i]}) | "
                        f"Length: {episode_timesteps[i]} | "
                        f"Env Reward: {episode_env_reward[i]:.2f} | "
                        f"DIAYN Reward: {episode_diayn_reward[i]:.2f} | "
                        f"Total Reward: {episode_reward[i]:.2f}"
                    )

                # Only env 0's episodes go to the metric logger, to avoid
                # multiple data points for the same timestep.
                if self.logger is not None and i == 0:
                    self.logger.update(
                        {
                            "episode/reward": episode_reward[i],
                            "episode/env_reward": episode_env_reward[i],
                            "episode/diayn_reward": episode_diayn_reward[i],
                            "episode/length": episode_timesteps[i],
                            f"skill/{skill_indices[i]}/reward": episode_env_reward[
                                i
                            ],
                        }
                    )
                    self.logger.log(self.num_timesteps)

                # Sample a new skill for the next episode
                skill_indices[i] = np.random.randint(0, self.n_skills)
                skills[i] = np.eye(self.n_skills)[skill_indices[i]]

                # Reset episode tracking variables
                episode_reward[i] = 0
                episode_env_reward[i] = 0
                episode_diayn_reward[i] = 0
                episode_timesteps[i] = 0

            # Evaluation
            # if self.eval_freq > 0 and t % self.eval_freq == 0:
            #     self.evaluate()

            # Save checkpoints
            if self.save_freq > 0 and t % self.save_freq == 0:
                self.save(f"checkpoint_{t}")

        # Final evaluation and saving
        # if self.eval_freq > 0:
        #     self.evaluate()

        if self.exp_dir is not None:
            self.save("final")

        # Log training completion
        if self.info_logger:
            self.info_logger.info(
                f"Training completed after {self.num_episodes} episodes"
            )
            for i, skill_rewards in enumerate(self.skill_episode_rewards):
                if skill_rewards:
                    avg_reward = np.mean(skill_rewards[-10:])
                    self.info_logger.info(f"Skill {i} final reward: {avg_reward:.2f}")
                else:
                    # Avoid np.mean([]) (RuntimeWarning + nan) for unused skills.
                    self.info_logger.info(
                        f"Skill {i} final reward: n/a (no completed episodes)"
                    )

    def evaluate(self):
        """
        Run deterministic evaluation rollouts for every skill.

        Each skill is encoded as a float32 one-hot vector and rolled out for
        ``self.n_eval_episodes`` episodes on ``self.eval_env``; episode ``ep`` of
        skill ``k`` is reset with seed ``self.seed + k + ep``. Actions come from
        ``self.select_action(..., evaluate=True)`` and the per-step discriminator
        reward from ``self.compute_diayn_reward(next_state, skill_index)``.

        Per-skill means (environment reward, episode length, discriminator
        reward) and the means over all skills are reported to
        ``self.info_logger`` and ``self.logger`` when those are set.

        Returns:
            list: mean environment reward of each skill, ordered by skill index.
        """

        def run_episode(skill_vec, skill_idx, ep):
            # One evaluation rollout -> (env return, length, summed discriminator reward).
            obs, _ = self.eval_env.reset(seed=self.seed + skill_idx + ep)
            ret = length = disc_ret = 0
            finished = False
            while not finished:
                act = self.select_action(obs, skill_vec, evaluate=True)
                next_obs, r, term, trunc, _ = self.eval_env.step(act)
                finished = term or trunc
                disc = self.compute_diayn_reward(next_obs, skill_idx)
                ret += r
                disc_ret += disc
                length += 1
                obs = next_obs
            return ret, length, disc_ret

        if self.info_logger:
            self.info_logger.info("\nEvaluating skills...")

        reward_means, length_means, disc_means = [], [], []

        for skill_idx in range(self.n_skills):
            skill_vec = np.zeros(self.n_skills, dtype=np.float32)
            skill_vec[skill_idx] = 1.0

            rollouts = [
                run_episode(skill_vec, skill_idx, ep)
                for ep in range(self.n_eval_episodes)
            ]

            mean_ret = np.mean([ro[0] for ro in rollouts])
            mean_len = np.mean([ro[1] for ro in rollouts])
            mean_disc = np.mean([ro[2] for ro in rollouts])

            reward_means.append(mean_ret)
            length_means.append(mean_len)
            disc_means.append(mean_disc)

            if self.info_logger:
                self.info_logger.info(
                    f"Skill {skill_idx}: Reward={mean_ret:.2f}, "
                    f"Length={mean_len:.2f}, Disc Reward={mean_disc:.2f}"
                )

            if self.logger is not None:
                prefix = f"eval/skill_{skill_idx}"
                self.logger.update(
                    {
                        f"{prefix}_reward": mean_ret,
                        f"{prefix}_length": mean_len,
                        f"{prefix}_disc_reward": mean_disc,
                    }
                )

        # Aggregate over skills.
        all_ret = np.mean(reward_means)
        all_len = np.mean(length_means)
        all_disc = np.mean(disc_means)

        if self.info_logger:
            self.info_logger.info(
                f"Overall: Reward={all_ret:.2f}, "
                f"Length={all_len:.2f}, Disc Reward={all_disc:.2f}\n"
            )

        if self.logger is not None:
            summary = {
                "eval/mean_reward": all_ret,
                "eval/mean_length": all_len,
                "eval/mean_disc_reward": all_disc,
            }
            self.logger.update(summary)
            self.logger.log(self.num_timesteps)

        return reward_means

    def save(self, filename):
        """
        Write a model checkpoint to ``<exp_dir>/checkpoints/<filename>.pt``.

        Does nothing when ``self.exp_dir`` is None. The ``checkpoints``
        sub-directory is created on demand. Policy and discriminator weights
        are stored as CPU tensors, so the file can be loaded without the
        training device. Scalar bookkeeping (skill count, alpha, counters,
        network dimensions, seed) is stored alongside them.

        Args:
            filename: checkpoint name without the ``.pt`` extension.
        """
        if self.exp_dir is None:
            return

        # Path() also accepts exp_dir given as a plain string.
        path = Path(self.exp_dir) / "checkpoints" / f"{filename}.pt"
        # torch.save does not create missing parent directories.
        path.parent.mkdir(parents=True, exist_ok=True)

        def _cpu_state_dict(module):
            # Move only the state-dict tensors to CPU instead of deep-copying the
            # whole module, which would temporarily duplicate it on the training
            # device. Reusing the returned OrderedDict keeps its `_metadata`.
            state = module.state_dict()
            for key in list(state):
                value = state[key]
                if torch.is_tensor(value):
                    state[key] = value.cpu()
            return state

        checkpoint = {
            "policy": _cpu_state_dict(self.policy),
            "discriminator": _cpu_state_dict(self.discriminator),
            "n_skills": self.n_skills,
            "alpha": self.alpha,
            "num_timesteps": self.num_timesteps,
            "num_episodes": self.num_episodes,
            "state_dim": self.state_dim,
            "action_dim": self.action_dim,
            "hidden_dims": self.hidden_dims,
            "seed": self.seed,
        }

        torch.save(checkpoint, path)

        if self.info_logger:
            self.info_logger.info(f"Model saved to {path}")
