import warnings

import gym
import numpy as np
import torch
from torch.nn import functional as F
import torch.nn as nn
import time

from models.decoder_dme import StateTransitionDecoder_dme, RewardDecoder_dme, TaskDecoder_dme
from models.encoder_dme import RNNEncoder_dme
from utils.helpers import get_task_dim, get_num_tasks, get_device
from utils.storage_vae import RolloutStorageVAE

import matplotlib.pyplot as plt
from pathlib import Path
from utils.helpers import get_device
from torch.utils.data import DataLoader
from torch.distributions.kl import kl_divergence


class VaribadVAEDME:
    """
    New GMVAE following Dilokthanakul et al 2017:
    - qφ(w | h) - style (Gaussian)  
    - qφ(z | h) - skill (Gaussian)
    - qφ(y | w, z) - analytic ∝ pθ(z|w,y) p(y)
    - can compute the ELBO loss with new formulation
    - can update the GMVAE (encoder+decoder)
    """

    def __init__(self, args, logger, get_iter_idx, error_handling=False):
        """
        Build the encoder, the enabled decoders, the two VAE rollout buffers
        and a single Adam optimiser over all trainable parameters.

        Args:
            args: configuration namespace read throughout this class.
            logger: object whose `add(name, value, iter_idx)` is used in `log`.
            get_iter_idx: callable returning the current training iteration.
            error_handling: forwarded unchanged to both rollout storages.
        """
        self.args = args
        self.logger = logger
        self.get_iter_idx = get_iter_idx

        # task information is only looked up when a task decoder is requested
        if self.args.decode_task:
            self.task_dim = get_task_dim(self.args)
            self.num_tasks = get_num_tasks(self.args)
        else:
            self.task_dim = None
            self.num_tasks = None
        self.dummy = 0

        # networks: encoder first, then decoders (None for the unused ones)
        self.encoder = self.initialise_encoder()
        decoders = self.initialise_decoder()
        self.state_decoder, self.reward_decoder, self.task_decoder = decoders

        # both rollout buffers share every setting except their capacity
        shared_storage_kwargs = dict(
            num_processes=self.args.num_processes,
            max_trajectory_len=self.args.max_trajectory_len,
            zero_pad=True,
            state_dim=self.args.state_dim,
            action_dim=self.args.action_dim,
            vae_buffer_add_thresh=self.args.vae_buffer_add_thresh,
            task_dim=self.task_dim,
            error_handling=error_handling,
        )
        self.rollout_storage = RolloutStorageVAE(max_num_rollouts=self.args.size_vae_buffer,
                                                 **shared_storage_kwargs)
        self.rollout_storage_virtual = RolloutStorageVAE(max_num_rollouts=0,
                                                         **shared_storage_kwargs)

        # Optional external "style" encoder (only loaded at test-time)
        self.style_net = None

        # one optimiser for the encoder plus every enabled decoder
        # (parameter order: encoder, reward decoder, state decoder, task decoder)
        trainable_params = list(self.encoder.parameters())
        if not self.args.disable_decoder:
            if self.args.decode_reward:
                trainable_params += list(self.reward_decoder.parameters())
            if self.args.decode_state:
                trainable_params += list(self.state_decoder.parameters())
            if self.args.decode_task:
                trainable_params += list(self.task_decoder.parameters())
        self.optimiser_vae = torch.optim.Adam(trainable_params, lr=self.args.lr_vae)

    def initialise_encoder(self):
        """ Builds the RNN encoder from self.args and moves it to the active device. """
        cfg = self.args
        encoder_kwargs = {
            'args': cfg,
            'layers_before_gru': cfg.encoder_layers_before_gru,
            'hidden_size': cfg.encoder_gru_hidden_size,
            'layers_after_gru': cfg.encoder_layers_after_gru,
            'class_dim': cfg.vae_mixture_num,
            'latent_dim': cfg.latent_dim,
            'action_dim': cfg.action_dim,
            'action_embed_dim': cfg.action_embedding_size,
            'state_dim': cfg.state_dim,
            'state_embed_dim': cfg.state_embedding_size,
            # rewards are passed to the encoder with size 1
            'reward_size': 1,
            'reward_embed_size': cfg.reward_embedding_size,
        }
        return RNNEncoder_dme(**encoder_kwargs).to(get_device())

    def initialise_decoder(self):
        """
        Builds the decoders requested in self.args.

        Returns:
            (state_decoder, reward_decoder, task_decoder); an entry is None when
            the corresponding decoder is switched off, and all three are None
            when decoding is disabled altogether.
        """
        cfg = self.args

        if cfg.disable_decoder:
            return None, None, None

        # without sampling, the decoders receive mean and variance concatenated,
        # so their latent input is twice as wide
        decoder_latent_dim = cfg.latent_dim
        if cfg.disable_stochasticity_in_latent:
            decoder_latent_dim = decoder_latent_dim * 2

        state_decoder = None
        reward_decoder = None
        task_decoder = None

        # state-transition decoder
        if cfg.decode_state:
            state_decoder = StateTransitionDecoder_dme(
                args=cfg,
                layers=cfg.state_decoder_layers,
                class_dim=cfg.vae_mixture_num,
                latent_dim=decoder_latent_dim,
                action_dim=cfg.action_dim,
                action_embed_dim=cfg.action_embedding_size,
                state_dim=cfg.state_dim,
                state_embed_dim=cfg.state_embedding_size,
                pred_type=cfg.state_pred_type,
                dropout_rate=cfg.dropout_rate,
            ).to(get_device())

        # reward decoder
        if cfg.decode_reward:
            reward_decoder = RewardDecoder_dme(
                args=cfg,
                layers=cfg.reward_decoder_layers,
                class_dim=cfg.vae_mixture_num,
                latent_dim=decoder_latent_dim,
                state_dim=cfg.state_dim,
                state_embed_dim=cfg.state_embedding_size,
                action_dim=cfg.action_dim,
                action_embed_dim=cfg.action_embedding_size,
                num_states=cfg.state_dim,
                multi_head=cfg.multihead_for_reward,
                pred_type=cfg.rew_pred_type,
                input_prev_state=cfg.input_prev_state,
                input_action=cfg.input_action,
                dropout_rate=cfg.dropout_rate,
            ).to(get_device())

        # task decoder
        if cfg.decode_task:
            assert self.task_dim != 0
            task_decoder = TaskDecoder_dme(
                class_dim=cfg.vae_mixture_num,
                latent_dim=decoder_latent_dim,
                layers=cfg.task_decoder_layers,
                task_dim=self.task_dim,
                num_tasks=self.num_tasks,
                pred_type=cfg.task_pred_type,
            ).to(get_device())

        return state_decoder, reward_decoder, task_decoder

    def log_normal(self, x, mu, var):
        """Logarithm of normal distribution with mean=mu and variance=var"""
        var = var + 1e-8
        return -0.5 * torch.sum(
            np.log(2.0 * np.pi) + torch.log(var) + torch.pow(x - mu, 2) / var, dim=-1)

    def kl_divergence_gaussians(self, mu1, logvar1, mu2, logvar2):
        """KL divergence between two Gaussians: KL(N(mu1,var1)||N(mu2,var2))"""
        var1 = torch.exp(logvar1)
        var2 = torch.exp(logvar2)
        
        kl = 0.5 * (logvar2 - logvar1 + var1 / var2 + ((mu1 - mu2) ** 2) / var2 - 1)
        return kl.sum(-1)  # sum over dimensions

    def occupancy_loss(self, y):
        """Occupancy loss that weights higher subtask indices more heavily.

        Each mixture index k = 1..K gets a coefficient c_k chosen by
        args.occ_loss_type:
            'linear': log(K) * k / K
            'log':    log(k)
            'exp':    log(K) * exp(k) / exp(K)
            other:    k
        The result is the mean over rows of Σ_k c_k * y_k.

        Args:
            y: analytic subtask composition (soft probabilities), mixture
               index on the last dimension
        """
        num_mixtures = self.args.vae_mixture_num
        loss_type = self.args.occ_loss_type
        component_rank = torch.arange(1.0, num_mixtures + 1)

        if loss_type == 'linear':
            coefficients = np.log(num_mixtures) * component_rank / num_mixtures
        elif loss_type == 'log':
            coefficients = torch.log(component_rank)
        elif loss_type == 'exp':
            coefficients = np.log(num_mixtures) * torch.exp(component_rank) / np.exp(num_mixtures)
        else:
            # unknown type: fall back to the raw index as the coefficient
            coefficients = component_rank

        coefficients = coefficients.to(get_device())
        weighted_occupancy = coefficients * y
        return torch.mean(torch.sum(weighted_occupancy, dim=-1))

    def _current_kl_cap(self):
        """Capacity subtracted from the discrete KL term; 0.0 when args has no `kl_y_cap`."""
        try:
            return self.args.kl_y_cap
        except AttributeError:
            return 0.0

    def compute_state_reconstruction_loss(self, latent, prev_obs, next_obs, action, return_predictions=False):
        """ Compute state reconstruction loss. """
        state_pred = self.state_decoder(latent, prev_obs, action)

        if self.args.state_pred_type == 'deterministic':
            loss_state = (state_pred - next_obs).pow(2).mean(dim=-1)
        elif self.args.state_pred_type == 'gaussian':
            state_pred_mean = state_pred[:, :state_pred.shape[1] // 2]
            state_pred_std = torch.exp(0.5 * state_pred[:, state_pred.shape[1] // 2:])
            m = torch.distributions.normal.Normal(state_pred_mean, state_pred_std)
            loss_state = -m.log_prob(next_obs).mean(dim=-1)
        else:
            raise NotImplementedError

        if return_predictions:
            return loss_state, state_pred
        else:
            return loss_state

    def compute_rew_reconstruction_loss(self, latent, prev_obs, next_obs, action, reward, return_predictions=False):
        """ Compute reward reconstruction loss. """
        if self.args.multihead_for_reward:
            rew_pred = self.reward_decoder(latent, None)
            if self.args.rew_pred_type == 'categorical':
                rew_pred = F.softmax(rew_pred, dim=-1)
            elif self.args.rew_pred_type == 'bernoulli':
                rew_pred = torch.sigmoid(rew_pred)

            env = gym.make(self.args.env_name)
            state_indices = env.task_to_id(next_obs).to(get_device())
            if state_indices.dim() < rew_pred.dim():
                state_indices = state_indices.unsqueeze(-1)
            rew_pred = rew_pred.gather(dim=-1, index=state_indices)
            rew_target = (reward == 1).float()
            if self.args.rew_pred_type == 'deterministic':
                loss_rew = (rew_pred - reward).pow(2).mean(dim=-1)
            elif self.args.rew_pred_type in ['categorical', 'bernoulli']:
                loss_rew = F.binary_cross_entropy(rew_pred, rew_target, reduction='none').mean(dim=-1)
            else:
                raise NotImplementedError
        else:
            # GMVAE: simplified decoder call - no y argument
            rew_pred = self.reward_decoder(latent, next_obs, prev_obs, action.float())
            if self.args.rew_pred_type == 'bernoulli':
                rew_pred = torch.sigmoid(rew_pred)
                rew_target = (reward == 1).float()
                loss_rew = F.binary_cross_entropy(rew_pred, rew_target, reduction='none').mean(dim=-1)
            elif self.args.rew_pred_type == 'deterministic':
                loss_rew = (rew_pred - reward).pow(2).mean(dim=-1)
            elif self.args.rew_pred_type == 'optimistic':
                loss_rew = ((rew_pred - reward).pow(2) + 0.1 * (rew_pred - 10.0).pow(2)).mean(dim=-1)
            else:
                raise NotImplementedError

        if return_predictions:
            return loss_rew, rew_pred
        else:
            return loss_rew

    def compute_task_reconstruction_loss(self, latent, task, return_predictions=False):
        """ Compute task reconstruction loss. """
        task_pred = self.task_decoder(latent)

        if self.args.task_pred_type == 'task_id':
            env = gym.make(self.args.env_name)
            task_target = env.task_to_id(task).to(get_device())
            task_target = task_target.expand(task_pred.shape[:-1]).reshape(-1)
            loss_task = F.cross_entropy(task_pred.view(-1, task_pred.shape[-1]),
                                        task_target, reduction='none').view(task_pred.shape[:-1])
        elif self.args.task_pred_type == 'task_description':
            loss_task = (task_pred - task).pow(2).mean(dim=-1)
        else:
            raise NotImplementedError

        if return_predictions:
            return loss_task, task_pred
        else:
            return loss_task

    def compute_loss(self, latent_mean, latent_logvar, vae_prev_obs, vae_next_obs, vae_actions,
                     vae_rewards, vae_tasks, trajectory_lens, y, w, w_mu, w_logvar):
        """
        Computes the GMVAE loss following Dilokthanakul et al 2017.
        New ELBO: E[log p(x|z)] - KL(q(w|h)||p(w)) - E[log q(z|h) - log p(z|w,y)]
        
        Now uses correct w_mu, w_logvar from encoder instead of approximations.
        """
        num_unique_trajectory_lens = len(np.unique(trajectory_lens))
        assert (num_unique_trajectory_lens == 1) or (self.args.vae_subsample_elbos and self.args.vae_subsample_decodes)
        assert not self.args.decode_only_past

        # cut down the batch to the longest trajectory length
        max_traj_len = np.max(trajectory_lens)
        latent_mean = latent_mean[:max_traj_len + 1]
        latent_logvar = latent_logvar[:max_traj_len + 1]
        vae_prev_obs = vae_prev_obs[:max_traj_len]
        vae_next_obs = vae_next_obs[:max_traj_len]
        vae_actions = vae_actions[:max_traj_len]
        vae_rewards = vae_rewards[:max_traj_len]

        # Extract variables (truncate to match)
        w = w[:max_traj_len + 1] if w.dim() > 2 else w
        w_mu = w_mu[:max_traj_len + 1] if w_mu.dim() > 2 else w_mu
        w_logvar = w_logvar[:max_traj_len + 1] if w_logvar.dim() > 2 else w_logvar
        y = y[:max_traj_len + 1] if y.dim() > 2 else y

        # === COMPUTE mu_y, logvar_y FROM w USING ENCODER'S PRIOR NETWORK ===
        seq_len, batch_size = latent_mean.shape[:2]
        
        # For mu_y, logvar_y: compute from w using the encoder's prior network
        w_flat = w.view(-1, w.shape[-1])
        prior_params = self.encoder.prior_net_w2yz(w_flat)  # (N, 2*K*D)
        prior_params = prior_params.view(-1, 2, self.args.vae_mixture_num, w.shape[-1])  # (N, 2, K, D)
        
        mu_y_flat = prior_params[:, 0]  # (N, K, D)
        logvar_y_flat = prior_params[:, 1]  # (N, K, D)
        
        # Reshape back to sequence format
        mu_y = mu_y_flat.view(seq_len, batch_size, self.args.vae_mixture_num, -1)
        logvar_y = logvar_y_flat.view(seq_len, batch_size, self.args.vae_mixture_num, -1)

        # take one sample for each ELBO term
        if not self.args.disable_stochasticity_in_latent:
            latent_samples = self.encoder._sample_gaussian(latent_mean, latent_logvar)
        else:
            latent_samples = torch.cat((latent_mean, latent_logvar), dim=-1)

        num_elbos = latent_samples.shape[0]
        num_decodes = vae_prev_obs.shape[0]
        batchsize = latent_samples.shape[1]

        # subsample elbo terms
        if self.args.vae_subsample_elbos is not None:
            if num_unique_trajectory_lens == 1:
                elbo_indices = torch.LongTensor(self.args.vae_subsample_elbos * batchsize).random_(0, num_elbos)
            else:
                elbo_indices = np.concatenate([np.random.choice(range(0, t + 1), self.args.vae_subsample_elbos,
                                                                replace=self.args.vae_subsample_elbos > (t+1)) for t in trajectory_lens])
                if max_traj_len < self.args.vae_subsample_elbos:
                    warnings.warn('The required number of ELBOs is larger than the shortest trajectory.')
            task_indices = torch.arange(batchsize).repeat(self.args.vae_subsample_elbos)

            latent_samples = latent_samples[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            y = y[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            w_mu = w_mu[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            w_logvar = w_logvar[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))

            # Handle mu_y, logvar_y which have extra dimension for mixture components
            if mu_y.dim() == 4:  # (seq, batch, K, D)
                mu_y = mu_y[elbo_indices, task_indices, :, :].reshape((self.args.vae_subsample_elbos, batchsize, self.args.vae_mixture_num, -1))
                logvar_y = logvar_y[elbo_indices, task_indices, :, :].reshape((self.args.vae_subsample_elbos, batchsize, self.args.vae_mixture_num, -1))

            num_elbos = latent_samples.shape[0]
        else:
            elbo_indices = None

        # expand the state/rew/action inputs to the decoder
        dec_prev_obs = vae_prev_obs.unsqueeze(0).expand((num_elbos, *vae_prev_obs.shape))
        dec_next_obs = vae_next_obs.unsqueeze(0).expand((num_elbos, *vae_next_obs.shape))
        dec_actions = vae_actions.unsqueeze(0).expand((num_elbos, *vae_actions.shape))
        dec_rewards = vae_rewards.unsqueeze(0).expand((num_elbos, *vae_rewards.shape))

        # subsample reconstruction terms
        if self.args.vae_subsample_decodes is not None:
            indices0 = torch.arange(num_elbos).repeat(self.args.vae_subsample_decodes * batchsize)
            if num_unique_trajectory_lens == 1:
                indices1 = torch.LongTensor(num_elbos * self.args.vae_subsample_decodes * batchsize).random_(0, num_decodes)
            else:
                indices1 = np.concatenate([np.random.choice(range(0, t), num_elbos * self.args.vae_subsample_decodes,
                                                            replace=True) for t in trajectory_lens])
            indices2 = torch.arange(batchsize).repeat(num_elbos * self.args.vae_subsample_decodes)
            dec_prev_obs = dec_prev_obs[indices0, indices1, indices2, :].reshape((num_elbos, self.args.vae_subsample_decodes, batchsize, -1))
            dec_next_obs = dec_next_obs[indices0, indices1, indices2, :].reshape((num_elbos, self.args.vae_subsample_decodes, batchsize, -1))
            dec_actions = dec_actions[indices0, indices1, indices2, :].reshape((num_elbos, self.args.vae_subsample_decodes, batchsize, -1))
            dec_rewards = dec_rewards[indices0, indices1, indices2, :].reshape((num_elbos, self.args.vae_subsample_decodes, batchsize, -1))
            num_decodes = dec_prev_obs.shape[1]

        # expand the latent
        dec_embedding = latent_samples.unsqueeze(0).expand((num_decodes, *latent_samples.shape)).transpose(1, 0)

        # Reconstruction losses (GMVAE - no y argument)
        if self.args.decode_reward:
            rew_reconstruction_loss = self.compute_rew_reconstruction_loss(dec_embedding, dec_prev_obs, dec_next_obs,
                                                                           dec_actions, dec_rewards, False)  # No y argument
            if self.args.vae_avg_elbo_terms:
                rew_reconstruction_loss = rew_reconstruction_loss.mean(dim=0)
            else:
                rew_reconstruction_loss = rew_reconstruction_loss.sum(dim=0)
            if self.args.vae_avg_reconstruction_terms:
                rew_reconstruction_loss = rew_reconstruction_loss.mean(dim=0)
            else:
                rew_reconstruction_loss = rew_reconstruction_loss.sum(dim=0)
            rew_reconstruction_loss = rew_reconstruction_loss.mean()
        else:
            rew_reconstruction_loss = 0

        if self.args.decode_state:
            state_reconstruction_loss = self.compute_state_reconstruction_loss(dec_embedding, dec_prev_obs,
                                                                               dec_next_obs, dec_actions)
            if self.args.vae_avg_elbo_terms:
                state_reconstruction_loss = state_reconstruction_loss.mean(dim=0)
            else:
                state_reconstruction_loss = state_reconstruction_loss.sum(dim=0)
            if self.args.vae_avg_reconstruction_terms:
                state_reconstruction_loss = state_reconstruction_loss.mean(dim=0)
            else:
                state_reconstruction_loss = state_reconstruction_loss.sum(dim=0)
            state_reconstruction_loss = state_reconstruction_loss.mean()
        else:
            state_reconstruction_loss = 0

        if self.args.decode_task:
            task_reconstruction_loss = self.compute_task_reconstruction_loss(latent_samples, vae_tasks)
            if self.args.vae_avg_elbo_terms:
                task_reconstruction_loss = task_reconstruction_loss.mean(dim=0)
            else:
                task_reconstruction_loss = task_reconstruction_loss.sum(dim=0)
            task_reconstruction_loss = task_reconstruction_loss.sum(dim=0).mean()
        else:
            task_reconstruction_loss = 0

        # === NEW GMVAE LOSSES ===
        # KL(q(w|h) || p(w)) where p(w) = N(0,I) - NOW USING CORRECT w_mu, w_logvar
        w_mu_flat = w_mu.view(-1, w_mu.shape[-1])
        w_logvar_flat = w_logvar.view(-1, w_logvar.shape[-1])
        kl_w = 0.5 * (w_mu_flat.pow(2) + torch.exp(w_logvar_flat) - 1 - w_logvar_flat).sum(-1).mean()

        # KL(q(z|h) || p(z|w,y)) using VECTORIZED per-component formula:
        # kl_z = Σ_k y_k * KL( q(z|h) || N(μ_k(w), σ_k²(w)) )
        # ----- align time-axes -------------------------------------------------------
        E, B, K, D = mu_y.shape                        # e.g. (50, 2, 5, 5)
        latent_mean   = latent_mean[:E]                # trim to same E
        latent_logvar = latent_logvar[:E]

        # ----- flatten (E,B) → (N = E·B) --------------------------------------------
        z_mu_flat     = latent_mean.view(E * B, D)             # (N, D)
        z_logvar_flat = latent_logvar.view(E * B, D)           # (N, D)

        mu_y_flat     = mu_y.view(E * B, K, D)                 # (N, K, D)
        logvar_y_flat = logvar_y.view(E * B, K, D)             # (N, K, D)
        y_flat        = y.view(E * B, K)                       # (N, K)  soft probs

        # ----- broadcast z stats to every mixture component -------------------------
        z_mu_exp      = z_mu_flat.unsqueeze(1)                 # (N, 1, D)
        z_logvar_exp  = z_logvar_flat.unsqueeze(1)             # (N, 1, D)

        var_z = torch.exp(z_logvar_exp)                        # (N, 1, D) → b’cast to (N,K,D)
        var_y = torch.exp(logvar_y_flat)                       # (N, K, D)

        # ----- element-wise KL per component ----------------------------------------
        kl_per_component = 0.5 * (
            logvar_y_flat - z_logvar_exp +                     # log σ_y² − log σ_z²
            var_z / var_y +                                    # σ_z² / σ_y²
            (z_mu_exp - mu_y_flat).pow(2) / var_y - 1          # (μ_z−μ_y)² / σ_y² − 1
        ).sum(-1)                                              # (N, K)

        # ----- expectation over q(y) and batch mean ----------------------------------
        kl_z = (y_flat * kl_per_component).sum(-1).mean()      # scalar

        # Discrete KL loss for y (categorical posterior)
        probs_y = y_flat  # y is already the analytic posterior probabilities
        safe_log = lambda x: torch.log(torch.clamp(x, min=1e-8))
        raw_kl_y = (probs_y * (safe_log(probs_y) + np.log(self.args.vae_mixture_num))).sum(dim=-1)
        kl_cap = self._current_kl_cap()
        kl_y = torch.clamp(raw_kl_y - kl_cap, min=0.0).mean()

        # Categorical loss is now called kl_y (leaving this here for compatibility)
        loss_cat = torch.tensor(0.0).to(get_device())
        
        # Occupancy loss on analytic y
        if self.args.occ_loss_coeff != 0:
            y_for_occ = y.view(-1, y.shape[-1])  # flatten for occupancy calculation
            loss_occ = self.occupancy_loss(y_for_occ)
        else:
            loss_occ = torch.tensor(0.0).to(get_device())

        # Extrapolation loss (simplified - no longer used)
        loss_ext = torch.tensor(0.0).to(get_device())

        return rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, kl_w, kl_z, kl_y, loss_cat, loss_occ, loss_ext

    def compute_vae_loss(self, update=False, pretrain_index=None):
        """
        Computes the weighted VAE loss on one mini-batch from the rollout
        storage, optionally takes an optimiser step, and logs all terms.

        Args:
            update: if True, back-propagate, clip gradients and step the optimiser.
            pretrain_index: forwarded to `log` to derive the logging iteration.

        Returns:
            The scalar loss tensor, or the int 0 when the storage is not ready
            or both the decoder and the KL term are disabled.
        """
        # nothing to train on yet
        if not self.rollout_storage.ready_for_update():
            return 0

        cfg = self.args
        # nothing to optimise
        if cfg.disable_decoder and cfg.disable_kl_term:
            return 0

        # get a mini-batch
        batch = self.rollout_storage.get_batch(batchsize=cfg.vae_batch_num_trajs)
        vae_prev_obs, vae_next_obs, vae_actions, vae_rewards, vae_tasks, trajectory_lens = batch

        # encode the trajectories; w_mu / w_logvar are requested for the KL on w
        tbptt_stepsize = cfg.tbptt_stepsize if hasattr(cfg, 'tbptt_stepsize') else None
        encoder_output = self.encoder(actions=vae_actions,
                                      states=vae_next_obs,
                                      rewards=vae_rewards,
                                      hidden_state=None,
                                      return_prior=True,
                                      return_w_params=True,
                                      detach_every=tbptt_stepsize,
                                      )

        # the encoder returns 9 values; the sample, the raw output and prob are not needed here
        _, latent_mean, latent_logvar, _, y, _, w, w_mu, w_logvar = encoder_output

        if cfg.split_batches_by_task:
            raise NotImplementedError
        if cfg.split_batches_by_elbo:
            raise NotImplementedError  # TODO: implement for GMVAE

        (rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss,
         kl_w, kl_z, kl_y, cat_loss, occ_loss, ext_loss) = self.compute_loss(
            latent_mean, latent_logvar, vae_prev_obs, vae_next_obs, vae_actions,
            vae_rewards, vae_tasks, trajectory_lens, y, w, w_mu, w_logvar)

        # weighted terms, accumulated in a fixed order
        weighted_terms = (
            cfg.rew_loss_coeff * rew_reconstruction_loss,
            cfg.state_loss_coeff * state_reconstruction_loss,
            cfg.task_loss_coeff * task_reconstruction_loss,
            cfg.kl_w_loss_coeff * kl_w,    # KL(q(w|h)||p(w))
            cfg.kl_z_loss_coeff * kl_z,    # KL(q(z|h)||p(z|w,y))
            cfg.kl_y_loss_coeff * kl_y,    # discrete KL on y
            cfg.cat_loss_coeff * cat_loss,
            cfg.occ_loss_coeff * occ_loss,
            cfg.ext_loss_coeff * ext_loss,
        )
        loss = weighted_terms[0]
        for term in weighted_terms[1:]:
            loss = loss + term
        # take average (this is the expectation over p(M))
        loss = loss.mean()

        # make sure we can compute gradients for every active term
        gradient_checks = (
            (cfg.decode_reward, rew_reconstruction_loss),
            (cfg.decode_state, state_reconstruction_loss),
            (cfg.decode_task, task_reconstruction_loss),
            (cfg.kl_w_loss_coeff != 0, kl_w),
            (cfg.kl_z_loss_coeff != 0, kl_z),
            (getattr(cfg, 'kl_y_loss_coeff', 0) != 0, kl_y),
        )
        for is_active, term in gradient_checks:
            if is_active:
                assert term.requires_grad

        # overall loss
        elbo_loss = loss.mean()

        if update:
            self.optimiser_vae.zero_grad()
            elbo_loss.backward()
            # clip gradients: encoder and each enabled decoder separately
            if cfg.encoder_max_grad_norm is not None:
                nn.utils.clip_grad_norm_(self.encoder.parameters(), cfg.encoder_max_grad_norm)
            if cfg.decoder_max_grad_norm is not None:
                if cfg.decode_reward:
                    nn.utils.clip_grad_norm_(self.reward_decoder.parameters(), cfg.decoder_max_grad_norm)
                if cfg.decode_state:
                    nn.utils.clip_grad_norm_(self.state_decoder.parameters(), cfg.decoder_max_grad_norm)
                if cfg.decode_task:
                    nn.utils.clip_grad_norm_(self.task_decoder.parameters(), cfg.decoder_max_grad_norm)
            # update
            self.optimiser_vae.step()

        self.log(elbo_loss, rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss,
                 kl_w, kl_z, kl_y, cat_loss, occ_loss, ext_loss, pretrain_index)

        return elbo_loss

    def log(self, elbo_loss, rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, kl_w, kl_z, kl_y, cat_loss, occ_loss, ext_loss, pretrain_index=None):

        if pretrain_index is None:
            curr_iter_idx = self.get_iter_idx()
        else:
            curr_iter_idx = - self.args.pretrain_len * self.args.num_vae_updates_per_pretrain + pretrain_index

        if curr_iter_idx % self.args.log_interval == 0:

            if self.args.decode_reward:
                self.logger.add('vae_losses/reward_reconstr_err', rew_reconstruction_loss.mean(), curr_iter_idx)
            if self.args.decode_state:
                self.logger.add('vae_losses/state_reconstr_err', state_reconstruction_loss.mean(), curr_iter_idx)
            if self.args.decode_task:
                self.logger.add('vae_losses/task_reconstr_err', task_reconstruction_loss.mean(), curr_iter_idx)

            if self.args.kl_w_loss_coeff != 0:
                self.logger.add('vae_losses/kl_w', kl_w.mean(), curr_iter_idx)
            if self.args.kl_z_loss_coeff != 0:
                self.logger.add('vae_losses/kl_z', kl_z.mean(), curr_iter_idx)
            if getattr(self.args, 'kl_y_loss_coeff', 0) != 0:
                self.logger.add('vae_losses/kl_y', kl_y.mean(), curr_iter_idx)
            self.logger.add('vae_losses/cat_loss', cat_loss.mean(), curr_iter_idx)
            if self.args.occ_loss_coeff != 0:
                self.logger.add('vae_losses/occ_loss', occ_loss.mean(), curr_iter_idx)
            self.logger.add('vae_losses/ext_loss', ext_loss.mean(), curr_iter_idx)
            self.logger.add('vae_losses/sum', elbo_loss, curr_iter_idx)

    # Dummy functions for compatibility - skill distillation not used
    def pretrain_with_skill_data(self, *args, **kwargs):
        """Compatibility stub: accepts any arguments and does nothing."""
        return None

    def set_skill_prior(self, *args, **kwargs):
        """Compatibility stub: accepts any arguments and does nothing."""
        return None

    def load_style_net(self, *args, **kwargs):
        """Compatibility stub: accepts any arguments and does nothing."""
        return None