import warnings

import gym
import numpy as np
import torch
from torch.nn import functional as F
import torch.nn as nn
import time

from models.decoder_mixture_ext import StateTransitionDecoder_mixture_ext, RewardDecoder_mixture_ext, TaskDecoder_mixture_ext
from models.encoder_mixture_ext import RNNEncoder_mixture_ext
from utils.helpers import get_task_dim, get_num_tasks, get_device
from utils.storage_vae import RolloutStorageVAE

import matplotlib.pyplot as plt
from pathlib import Path
from utils.helpers import get_device
from torch.utils.data import DataLoader
from torch.distributions.kl import kl_divergence


class VaribadVAEMixtureExt:
    """
    GMVAE of SDVT:
    - has an encoder (Gaussian mixture) and decoder
    - can compute the ELBO loss (including categorical)
    - can update the GMVAE (encoder+decoder)
    - includes the dispersion layer
    """

    def __init__(self, args, logger, get_iter_idx, error_handling=False):
        """Assemble encoder, decoders, rollout buffers, mixture priors and the optimiser.

        Args:
            args: configuration namespace; every hyper-parameter is read from it.
            logger: stored unchanged on the instance.
            get_iter_idx: stored unchanged on the instance.
            error_handling: forwarded unchanged to both ``RolloutStorageVAE`` buffers.
        """
        self.args = args
        self.logger = logger
        self.get_iter_idx = get_iter_idx

        # task meta-data is only looked up when the task decoder is switched on
        if self.args.decode_task:
            self.task_dim = get_task_dim(self.args)
            self.num_tasks = get_num_tasks(self.args)
        else:
            self.task_dim = None
            self.num_tasks = None
        self.dummy = 0

        # encoder first, then the decoders (unused decoders come back as None)
        self.encoder = self.initialise_encoder()
        decoders = self.initialise_decoder()
        self.state_decoder, self.reward_decoder, self.task_decoder = decoders

        # The VAE keeps its own rollout buffers (separate from the data the
        # on-policy RL algorithm uses). Both buffers share every setting except
        # their capacity.
        storage_kwargs = dict(
            num_processes=self.args.num_processes,
            max_trajectory_len=self.args.max_trajectory_len,
            zero_pad=True,
            state_dim=self.args.state_dim,
            action_dim=self.args.action_dim,
            vae_buffer_add_thresh=self.args.vae_buffer_add_thresh,
            task_dim=self.task_dim,
            error_handling=error_handling,
        )
        self.rollout_storage = RolloutStorageVAE(max_num_rollouts=self.args.size_vae_buffer,
                                                 **storage_kwargs)
        self.rollout_storage_virtual = RolloutStorageVAE(max_num_rollouts=0,
                                                         **storage_kwargs)

        # Per-skill Gaussian priors mu_y / sigma^2_y (learned offline).
        # Note that they are not handed to `optimiser_vae` below.
        prior_shape = (self.args.vae_mixture_num, self.args.latent_dim)
        self.mu_y = nn.Parameter(torch.zeros(prior_shape))
        self.sigma_y = nn.Parameter(torch.ones(prior_shape))

        # Optional external "style" encoder (only loaded at test-time)
        self.style_net = None

        # One Adam optimiser over the encoder plus every active decoder.
        trainable = list(self.encoder.parameters())
        if not self.args.disable_decoder:
            candidates = (
                (self.args.decode_reward, self.reward_decoder),
                (self.args.decode_state, self.state_decoder),
                (self.args.decode_task, self.task_decoder),
            )
            for enabled, decoder in candidates:
                if enabled:
                    trainable.extend(decoder.parameters())
        self.optimiser_vae = torch.optim.Adam(trainable, lr=self.args.lr_vae)

    def initialise_encoder(self):
        """Build the mixture RNN encoder from ``self.args`` and move it to the device."""
        cfg = self.args
        encoder_kwargs = dict(
            args=cfg,
            layers_before_gru=cfg.encoder_layers_before_gru,
            hidden_size=cfg.encoder_gru_hidden_size,
            layers_after_gru=cfg.encoder_layers_after_gru,
            class_dim=cfg.vae_mixture_num,
            latent_dim=cfg.latent_dim,
            action_dim=cfg.action_dim,
            action_embed_dim=cfg.action_embedding_size,
            state_dim=cfg.state_dim,
            state_embed_dim=cfg.state_embedding_size,
            reward_size=1,  # the reward input width is fixed to 1 here
            reward_embed_size=cfg.reward_embedding_size,
        )
        return RNNEncoder_mixture_ext(**encoder_kwargs).to(get_device())

    def initialise_decoder(self):
        """ Initialises and returns the (state/reward/task) decoder as specified in self.args """

        if self.args.disable_decoder:
            return None, None, None

        latent_dim = self.args.latent_dim
        # if we don't sample embeddings for the decoder, we feed in mean & variance
        if self.args.disable_stochasticity_in_latent:
            latent_dim *= 2

        # initialise state decoder for VAE
        if self.args.decode_state:
            state_decoder = StateTransitionDecoder_mixture_ext(
                args=self.args,
                layers=self.args.state_decoder_layers,
                class_dim = self.args.vae_mixture_num,
                latent_dim=latent_dim,
                action_dim=self.args.action_dim,
                action_embed_dim=self.args.action_embedding_size,
                state_dim=self.args.state_dim,
                state_embed_dim=self.args.state_embedding_size,
                pred_type=self.args.state_pred_type,
                dropout_rate = self.args.dropout_rate
            ).to(get_device())
        else:
            state_decoder = None

        # initialise reward decoder for VAE
        if self.args.decode_reward:
            reward_decoder = RewardDecoder_mixture_ext(
                args=self.args,
                layers=self.args.reward_decoder_layers,
                class_dim=self.args.vae_mixture_num,
                latent_dim=latent_dim,
                state_dim=self.args.state_dim,
                state_embed_dim=self.args.state_embedding_size,
                action_dim=self.args.action_dim,
                action_embed_dim=self.args.action_embedding_size,
                num_states=self.args.state_dim, #TODO: CHECK (but generally this is unused)
                multi_head=self.args.multihead_for_reward,
                pred_type=self.args.rew_pred_type,
                input_prev_state=self.args.input_prev_state,
                input_action=self.args.input_action,
                dropout_rate=self.args.dropout_rate
            ).to(get_device())
        else:
            reward_decoder = None

        # initialise task decoder for VAE
        if self.args.decode_task:
            assert self.task_dim != 0
            task_decoder = TaskDecoder_mixture_ext(
                class_dim=self.args.vae_mixture_num,
                latent_dim=latent_dim,
                layers=self.args.task_decoder_layers,
                task_dim=self.task_dim,
                num_tasks=self.num_tasks,
                pred_type=self.args.task_pred_type,
            ).to(get_device())
        else:
            task_decoder = None

        return state_decoder, reward_decoder, task_decoder

    def log_normal(self, x, mu, var):
        """Log-density of a diagonal Gaussian, summed over the last dimension.

        Computes ``-0.5 * sum(log(2*pi) + log(var) + (x - mu)^2 / var)``.

        TODO: maybe the prior should be changed to previous steps as in variBAD.

        Args:
            x: tensor of points at which the density is evaluated.
            mu: tensor of means.
            var: tensor of variances; ``1e-8`` is added before use so the log
                and the division stay finite.

        Returns:
            Tensor with the last dimension reduced away.
        """
        safe_var = var + 1e-8
        squared_dev = torch.pow(x - mu, 2)
        per_dim = np.log(2.0 * np.pi) + torch.log(safe_var) + squared_dev / safe_var
        return -0.5 * torch.sum(per_dim, dim=-1)

    def gaussian_loss(self, z, z_mu, z_var, z_mu_prior, z_var_prior):
        """Variational loss when using labeled data without considering reconstruction loss
           loss = log q(z|x,y) - log p(z) - log p(y)
        Args:
           z: (array) array containing the gaussian latent variable
           z_mu: (array) array containing the mean of the inference model
           z_var: (array) array containing the variance of the inference model
           z_mu_prior: (array) array containing the prior mean of the generative model
           z_var_prior: (array) array containing the prior variance of the generative mode

        Returns:
           output: (array/float) depending on average parameters the result will be the mean
                                  of all the sample losses or an array with the losses per sample
        """
        loss = self.log_normal(z, z_mu, z_var) - self.log_normal(z, z_mu_prior, z_var_prior)
        return loss.mean()

    def entropy(self, logits, targets):
        """Cross-entropy between ``targets`` and ``softmax(logits)``.

        loss = mean over samples of ``-sum_k targets_k * log_softmax(logits)_k``

        Args:
            logits: unnormalised scores of the categorical variable (last dim = classes).
            targets: weights multiplied with the log-probabilities (same shape as ``logits``).

        Returns:
            Scalar tensor with the averaged loss.
        """
        log_probs = F.log_softmax(logits, dim=-1)
        per_sample = (targets * log_probs).sum(dim=-1)
        return -per_sample.mean()

    def occupancy_loss(self, y):
        """occupancy loss to suppress usage of larger subtask index
            loss = logK * (1/f(K)) * -Σ (1,...,f(K)) dot y for linear
        Args:
            y: sampled subtask composition
        args.type: How f(K) is defined, linear square, or exponential
        maximum is set as log(K) to match magnitude of the entropy loss
        """
        if self.args.occ_loss_type == 'linear':
            occ_coeff = np.log(self.args.vae_mixture_num) * torch.arange(1.0,self.args.vae_mixture_num+1)/self.args.vae_mixture_num
        elif self.args.occ_loss_type == 'log':
            occ_coeff = torch.log(torch.arange(1.0,self.args.vae_mixture_num+1))
        elif self.args.occ_loss_type == 'exp':
            occ_coeff = np.log(self.args.vae_mixture_num) * torch.exp(torch.arange(1.0,self.args.vae_mixture_num+1))/np.exp(self.args.vae_mixture_num)

        occ_loss = occ_coeff.to(get_device()) * y

        return torch.mean(torch.sum(occ_loss, dim=-1))

    def compute_state_reconstruction_loss(self, latent, prev_obs, next_obs, action, return_predictions=False):
        """ Compute state reconstruction loss.
        (No reduction of loss along batch dimension is done here; sum/avg has to be done outside) """

        state_pred = self.state_decoder(latent, prev_obs, action)

        if self.args.state_pred_type == 'deterministic':
            loss_state = (state_pred - next_obs).pow(2).mean(dim=-1)
        elif self.args.state_pred_type == 'gaussian':  # TODO: untested!
            state_pred_mean = state_pred[:, :state_pred.shape[1] // 2]
            state_pred_std = torch.exp(0.5 * state_pred[:, state_pred.shape[1] // 2:])
            m = torch.distributions.normal.Normal(state_pred_mean, state_pred_std)
            loss_state = -m.log_prob(next_obs).mean(dim=-1)
        else:
            raise NotImplementedError

        if return_predictions:
            return loss_state, state_pred
        else:
            return loss_state

    def compute_rew_reconstruction_loss(self, latent, prev_obs, next_obs, action, reward, return_predictions=False, y=None):
        """ Compute reward reconstruction loss.
        (No reduction of loss along batch dimension is done here; sum/avg has to be done outside) """

        if self.args.multihead_for_reward: #not tested for mixture yet
            rew_pred = self.reward_decoder(latent, None)
            if self.args.rew_pred_type == 'categorical':
                rew_pred = F.softmax(rew_pred, dim=-1)
            elif self.args.rew_pred_type == 'bernoulli':
                rew_pred = torch.sigmoid(rew_pred)

            env = gym.make(self.args.env_name)
            state_indices = env.task_to_id(next_obs).to(get_device())
            if state_indices.dim() < rew_pred.dim():
                state_indices = state_indices.unsqueeze(-1)
            rew_pred = rew_pred.gather(dim=-1, index=state_indices)
            rew_target = (reward == 1).float()
            if self.args.rew_pred_type == 'deterministic':  # TODO: untested!
                loss_rew = (rew_pred - reward).pow(2).mean(dim=-1)
            elif self.args.rew_pred_type in ['categorical', 'bernoulli']:
                loss_rew = F.binary_cross_entropy(rew_pred, rew_target, reduction='none').mean(dim=-1)
            else:
                raise NotImplementedError
        else:
            rew_pred, y_mu, y_var, h_hat = self.reward_decoder(latent, next_obs, prev_obs, action.float(), y)
            if self.args.rew_pred_type == 'bernoulli':  # TODO: untested!
                rew_pred = torch.sigmoid(rew_pred)
                rew_target = (reward == 1).float()  # TODO: necessary?
                loss_rew = F.binary_cross_entropy(rew_pred, rew_target, reduction='none').mean(dim=-1)
            elif self.args.rew_pred_type == 'deterministic': #ONLY THIS IS DONE FOR NOW 220810
                loss_rew = (rew_pred - reward).pow(2).mean(dim=-1)
            elif self.args.rew_pred_type == 'optimistic':
                loss_rew = ((rew_pred - reward).pow(2) + 0.1 * (rew_pred - 10.0).pow(2)).mean(dim=-1) #predict reward for unseen states larger, no overfit 0 reward
            else:
                raise NotImplementedError

        if return_predictions:
            return loss_rew, rew_pred, y_mu, y_var, h_hat
        else:
            return loss_rew, y_mu, y_var, h_hat

    def compute_task_reconstruction_loss(self, latent, task, return_predictions=False):
        """ Compute task reconstruction loss.
        (No reduction of loss along batch dimension is done here; sum/avg has to be done outside) """

        task_pred = self.task_decoder(latent)

        if self.args.task_pred_type == 'task_id':
            env = gym.make(self.args.env_name)
            task_target = env.task_to_id(task).to(get_device())
            # expand along first axis (number of ELBO terms)
            task_target = task_target.expand(task_pred.shape[:-1]).reshape(-1)
            loss_task = F.cross_entropy(task_pred.view(-1, task_pred.shape[-1]),
                                        task_target, reduction='none').view(task_pred.shape[:-1])
        elif self.args.task_pred_type == 'task_description':
            loss_task = (task_pred - task).pow(2).mean(dim=-1)
        else:
            raise NotImplementedError

        if return_predictions:
            return loss_task, task_pred
        else:
            return loss_task

    def compute_loss(self, latent_mean, latent_logvar, vae_prev_obs, vae_next_obs, vae_actions,
                     vae_rewards, vae_tasks, trajectory_lens, y, z, mu, var, logits, prob, output):
        """
        Computes the VAE loss for the given data.
        Batches everything together and therefore needs all trajectories to be of the same length
        (unless both ELBO terms and decode terms are subsampled).
        (Important because we need to separate ELBOs and decoding terms so can't collapse those dimensions)

        Args:
            latent_mean, latent_logvar: encoder outputs, indexed here as [elbo step, trajectory, dim].
            vae_prev_obs, vae_next_obs, vae_actions, vae_rewards: trajectory data, indexed here as
                [timestep, trajectory, dim].
            vae_tasks: ground-truth tasks, only used when args.decode_task is set.
            trajectory_lens: length of every trajectory in the batch.
            y, z, mu, var, logits, prob, output: further encoder outputs, indexed here as
                [elbo step, trajectory, dim]; they are subsampled together with the latents.

        Returns:
            Tuple (rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss,
            loss_gauss, loss_cat, loss_occ, loss_ext). Disabled reconstruction losses and a
            disabled occupancy loss are returned as the integer 0.

        NOTE(review): y_mu, y_var and h_hat are only bound inside the `args.decode_reward` branch,
        yet they are used unconditionally for loss_gauss and loss_ext further down.
        """

        num_unique_trajectory_lens = len(np.unique(trajectory_lens))

        # mixed trajectory lengths are only supported when both subsampling options are active
        assert (num_unique_trajectory_lens == 1) or (self.args.vae_subsample_elbos and self.args.vae_subsample_decodes)
        assert not self.args.decode_only_past

        # cut down the batch to the longest trajectory length
        # this way we can preserve the structure
        # but we will waste some computation on zero-padded trajectories that are shorter than max_traj_len
        max_traj_len = np.max(trajectory_lens)
        # the latents carry one extra leading entry compared to the trajectory data (hence the +1)
        latent_mean = latent_mean[:max_traj_len + 1]
        latent_logvar = latent_logvar[:max_traj_len + 1]
        vae_prev_obs = vae_prev_obs[:max_traj_len]
        vae_next_obs = vae_next_obs[:max_traj_len]
        vae_actions = vae_actions[:max_traj_len]
        vae_rewards = vae_rewards[:max_traj_len]

        # take one sample for each ELBO term
        if not self.args.disable_stochasticity_in_latent:
            latent_samples = self.encoder._sample_gaussian(latent_mean, latent_logvar)
        else:
            # no sampling: feed mean and log-variance side by side instead
            latent_samples = torch.cat((latent_mean, latent_logvar), dim=-1)


        num_elbos = latent_samples.shape[0]
        num_decodes = vae_prev_obs.shape[0]
        batchsize = latent_samples.shape[1]  # number of trajectories


        # subsample elbo terms
        #   shape before: num_elbos * batchsize * dim
        #   shape after: vae_subsample_elbos * batchsize * dim
        if self.args.vae_subsample_elbos is not None:
            # randomly choose which elbo's to subsample
            if num_unique_trajectory_lens == 1:
                elbo_indices = torch.LongTensor(self.args.vae_subsample_elbos * batchsize).random_(0, num_elbos)    # select diff elbos for each task
            else:
                # if we have different trajectory lengths, subsample elbo indices separately
                # up to their maximum possible encoding length;
                # only allow duplicates if the sample size would be larger than the number of samples
                elbo_indices = np.concatenate([np.random.choice(range(0, t + 1), self.args.vae_subsample_elbos,
                                                                replace=self.args.vae_subsample_elbos > (t+1)) for t in trajectory_lens])
                if max_traj_len < self.args.vae_subsample_elbos:
                    warnings.warn('The required number of ELBOs is larger than the shortest trajectory, '
                                  'so there will be duplicates in your batch.'
                                  'To avoid this use --split_batches_by_elbo or --split_batches_by_task.')
            task_indices = torch.arange(batchsize).repeat(self.args.vae_subsample_elbos)  # for selection mask

            # apply the same (elbo, trajectory) selection to the latents and to every encoder output
            latent_samples = latent_samples[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            y = y[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            z = z[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            mu = mu[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            var = var[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            logits = logits[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            prob = prob[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))
            output = output[elbo_indices, task_indices, :].reshape((self.args.vae_subsample_elbos, batchsize, -1))

            num_elbos = latent_samples.shape[0]
        else:
            # not read again below
            elbo_indices = None


        # expand the state/rew/action inputs to the decoder (to match size of latents)
        # shape will be: [num elbos] x [len trajectory (reconstruction terms)] x [num trajectories in batch] x [dimension]
        # (the trailing numbers on the next lines look like example shapes from one run)
        dec_prev_obs = vae_prev_obs.unsqueeze(0).expand((num_elbos, *vae_prev_obs.shape)) #50 5000 10 40
        dec_next_obs = vae_next_obs.unsqueeze(0).expand((num_elbos, *vae_next_obs.shape)) #50 5000 10 40
        dec_actions = vae_actions.unsqueeze(0).expand((num_elbos, *vae_actions.shape)) #50 5000 10 4
        dec_rewards = vae_rewards.unsqueeze(0).expand((num_elbos, *vae_rewards.shape)) #50 5000 10 1

        # subsample reconstruction terms
        if self.args.vae_subsample_decodes is not None:
            # shape before: vae_subsample_elbos * num_decodes * batchsize * dim
            # shape after: vae_subsample_elbos * vae_subsample_decodes * batchsize * dim
            # (Note that this will always have duplicates given how we set up the code)
            # indices0 / indices1 / indices2 address the elbo / timestep / trajectory axes respectively
            indices0 = torch.arange(num_elbos).repeat(self.args.vae_subsample_decodes * batchsize)
            if num_unique_trajectory_lens == 1:
                indices1 = torch.LongTensor(num_elbos * self.args.vae_subsample_decodes * batchsize).random_(0, num_decodes)
            else:
                # draw timesteps per trajectory so that only valid (non-padded) steps are picked
                indices1 = np.concatenate([np.random.choice(range(0, t), num_elbos * self.args.vae_subsample_decodes,
                                                            replace=True) for t in trajectory_lens])
            indices2 = torch.arange(batchsize).repeat(num_elbos * self.args.vae_subsample_decodes)
            dec_prev_obs = dec_prev_obs[indices0, indices1, indices2, :].reshape((num_elbos, self.args.vae_subsample_decodes, batchsize, -1)) #50 50 10 40
            dec_next_obs = dec_next_obs[indices0, indices1, indices2, :].reshape((num_elbos, self.args.vae_subsample_decodes, batchsize, -1)) #50 50 10 40
            dec_actions = dec_actions[indices0, indices1, indices2, :].reshape((num_elbos, self.args.vae_subsample_decodes, batchsize, -1)) #50 50 10 4
            dec_rewards = dec_rewards[indices0, indices1, indices2, :].reshape((num_elbos, self.args.vae_subsample_decodes, batchsize, -1)) #50 50 10 1
            num_decodes = dec_prev_obs.shape[1]

        # expand the latent (to match the number of state/rew/action inputs to the decoder)
        # shape will be: [num elbos] x [num reconstruction terms] x [num trajectories in batch] x [dimension]
        dec_embedding = latent_samples.unsqueeze(0).expand((num_decodes, *latent_samples.shape)).transpose(1, 0)
        # the encoder `output` is expanded the same way; it is the target of the extrapolation loss below
        output_embedding = output.unsqueeze(0).expand((num_decodes, *output.shape)).transpose(1, 0)

        if self.args.decode_reward:
            # compute reconstruction loss for this trajectory (for each timestep that was encoded, decode everything and sum it up)
            # shape: [num_elbo_terms] x [num_reconstruction_terms] x [num_trajectories]
            # y_mu / y_var / h_hat are extra outputs of the reward decoder that feed loss_gauss and loss_ext below
            rew_reconstruction_loss, y_mu, y_var, h_hat = self.compute_rew_reconstruction_loss(dec_embedding, dec_prev_obs, dec_next_obs,
                                                                           dec_actions, dec_rewards,False, y) #TODO: subsample y as well

            # avg/sum across individual ELBO terms
            if self.args.vae_avg_elbo_terms:
                rew_reconstruction_loss = rew_reconstruction_loss.mean(dim=0)
            else:
                rew_reconstruction_loss = rew_reconstruction_loss.sum(dim=0)
            # avg/sum across individual reconstruction terms
            if self.args.vae_avg_reconstruction_terms:
                rew_reconstruction_loss = rew_reconstruction_loss.mean(dim=0)
            else:
                rew_reconstruction_loss = rew_reconstruction_loss.sum(dim=0)
            # average across tasks
            rew_reconstruction_loss = rew_reconstruction_loss.mean()
        else:
            rew_reconstruction_loss = 0

        if self.args.decode_state:
            state_reconstruction_loss = self.compute_state_reconstruction_loss(dec_embedding, dec_prev_obs,
                                                                               dec_next_obs, dec_actions)
            # avg/sum across individual ELBO terms
            if self.args.vae_avg_elbo_terms:
                state_reconstruction_loss = state_reconstruction_loss.mean(dim=0)
            else:
                state_reconstruction_loss = state_reconstruction_loss.sum(dim=0)
            # avg/sum across individual reconstruction terms
            if self.args.vae_avg_reconstruction_terms:
                state_reconstruction_loss = state_reconstruction_loss.mean(dim=0)
            else:
                state_reconstruction_loss = state_reconstruction_loss.sum(dim=0)
            # average across tasks
            state_reconstruction_loss = state_reconstruction_loss.mean()
        else:
            state_reconstruction_loss = 0

        if self.args.decode_task:
            # the task is decoded from the (un-expanded) latent samples, one prediction per ELBO term
            task_reconstruction_loss = self.compute_task_reconstruction_loss(latent_samples, vae_tasks)
            # avg/sum across individual ELBO terms
            if self.args.vae_avg_elbo_terms:
                task_reconstruction_loss = task_reconstruction_loss.mean(dim=0)
            else:
                task_reconstruction_loss = task_reconstruction_loss.sum(dim=0)
            # sum the elbos, average across tasks
            task_reconstruction_loss = task_reconstruction_loss.sum(dim=0).mean()
        else:
            task_reconstruction_loss = 0

        # NOTE(review): y_mu / y_var (and h_hat below) exist only if args.decode_reward was set above
        loss_gauss = self.gaussian_loss(z, mu, var, y_mu, y_var)
        # negated cross-entropy of prob w.r.t. softmax(logits), shifted by log(K) (entropy of the uniform distribution)
        loss_cat = -self.entropy(logits, prob) - np.log(1.0/self.args.vae_mixture_num) #uniform entropy
        if self.args.occ_loss_coeff!=0:
            loss_occ = self.occupancy_loss(y)
        else:
            loss_occ = 0

        # the target is detached, so this term sends no gradient into `output`
        loss_ext = (output_embedding.detach()-h_hat).pow(2).mean() #extrapolation loss from DiaMetR

        return rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, loss_gauss, loss_cat, loss_occ, loss_ext

    def compute_loss_split_batches_by_elbo(self, latent_mean, latent_logvar, vae_prev_obs, vae_next_obs, vae_actions,
                                           vae_rewards, vae_tasks, trajectory_lens):

        """
        Loop over the elvo_t terms to compute losses per t.
        Saves some memory if batch sizes are very large,
        or if trajectory lengths are different, or if we decode only the past.
        """

        rew_reconstruction_loss = []
        state_reconstruction_loss = []
        task_reconstruction_loss = []

        assert len(np.unique(trajectory_lens)) == 1
        n_horizon = np.unique(trajectory_lens)[0]
        n_elbos = latent_mean.shape[0]  # includes the prior

        # for each elbo term (including one for the prior)...
        for idx_elbo in range(n_elbos):

            # get the embedding values (size: traj_length+1 * latent_dim; the +1 is for the prior)
            curr_means = latent_mean[idx_elbo]
            curr_logvars = latent_logvar[idx_elbo]

            # take one sample for each task
            if not self.args.disable_stochasticity_in_latent:
                curr_samples = self.encoder._sample_gaussian(curr_means, curr_logvars)
            else:
                curr_samples = torch.cat((latent_mean, latent_logvar))

            # if the size of what we decode is always the same, we can speed up creating the batches
            if not self.args.decode_only_past:

                # expand the latent to match the (x, y) pairs of the decoder
                dec_embedding = curr_samples.unsqueeze(0).expand((n_horizon, *curr_samples.shape))
                dec_embedding_task = curr_samples

                dec_prev_obs = vae_prev_obs
                dec_next_obs = vae_next_obs
                dec_actions = vae_actions
                dec_rewards = vae_rewards

            # otherwise, we unfortunately have to loop!
            # loop through the lengths we are feeding into the encoder for that trajectory (starting with prior)
            # (these are the different ELBO_t terms)
            else:

                # get the index until which we want to decode
                # (i.e. eithe runtil curr timestep or entire trajectory including future)
                if self.args.decode_only_past:
                    dec_from = 0
                    dec_until = idx_elbo
                else:
                    dec_from = 0
                    dec_until = n_horizon

                if dec_from == dec_until:
                    continue

                # (1) ... get the latent sample after feeding in some data (determined by len_encoder) & expand (to number of outputs)
                # num latent samples x embedding size
                dec_embedding = curr_samples.unsqueeze(0).expand(dec_until - dec_from, *curr_samples.shape)
                dec_embedding_task = curr_samples
                # (2) ... get the predictions for the trajectory until the timestep we're interested in
                dec_prev_obs = vae_prev_obs[dec_from:dec_until]
                dec_next_obs = vae_next_obs[dec_from:dec_until]
                dec_actions = vae_actions[dec_from:dec_until]
                dec_rewards = vae_rewards[dec_from:dec_until]

            if self.args.decode_reward:
                # compute reconstruction loss for this trajectory (for each timestep that was encoded, decode everything and sum it up)
                # size: if all trajectories are of same length [num_elbo_terms x num_reconstruction_terms], otherwise it's flattened into one
                rrc = self.compute_rew_reconstruction_loss(dec_embedding, dec_prev_obs, dec_next_obs, dec_actions, dec_rewards)
                # sum up the reconstruction terms; average over tasks
                rrc = rrc.sum(dim=0).mean()
                rew_reconstruction_loss.append(rrc)

            if self.args.decode_state:
                src = self.compute_state_reconstruction_loss(dec_embedding, dec_prev_obs, dec_next_obs, dec_actions)
                # sum up the reconstruction terms; average over tasks
                src = src.sum(dim=0).mean()
                state_reconstruction_loss.append(src)

            if self.args.decode_task:
                trc = self.compute_task_reconstruction_loss(dec_embedding_task, vae_tasks)
                # average across tasks
                trc = trc.mean()
                task_reconstruction_loss.append(trc)

        # sum the ELBO_t terms
        if self.args.decode_reward:
            rew_reconstruction_loss = torch.stack(rew_reconstruction_loss)
            rew_reconstruction_loss = rew_reconstruction_loss.sum()
        else:
            rew_reconstruction_loss = 0

        if self.args.decode_state:
            state_reconstruction_loss = torch.stack(state_reconstruction_loss)
            state_reconstruction_loss = state_reconstruction_loss.sum()
        else:
            state_reconstruction_loss = 0

        if self.args.decode_task:
            task_reconstruction_loss = torch.stack(task_reconstruction_loss)
            task_reconstruction_loss = task_reconstruction_loss.sum()
        else:
            task_reconstruction_loss = 0

        return rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss

    def compute_vae_loss(self, update=False, pretrain_index=None):
        """ Returns the VAE loss """

        if not self.rollout_storage.ready_for_update():
            return 0

        if self.args.disable_decoder and self.args.disable_kl_term:
            return 0

        # get a mini-batch
        #print('vae get batch called')
        vae_prev_obs, vae_next_obs, vae_actions, vae_rewards, vae_tasks, \
        trajectory_lens = self.rollout_storage.get_batch(batchsize=self.args.vae_batch_num_trajs)
        # vae_prev_obs will be of size: max trajectory len x num trajectories x dimension of observations

        # pass through encoder (outputs will be: (max_traj_len+1) x number of rollouts x latent_dim -- includes the prior!)
        latent_sample, latent_mean, latent_logvar, output, \
        y, z, mu, var, logits, prob= self.encoder(actions=vae_actions,
                                                        states=vae_next_obs,
                                                        rewards=vae_rewards,
                                                        hidden_state=None,
                                                        return_prior=True,
                                                        detach_every=self.args.tbptt_stepsize if hasattr(self.args, 'tbptt_stepsize') else None,
                                                        )
        if self.args.split_batches_by_task:
            raise NotImplementedError
            losses = self.compute_loss_split_batches_by_task(latent_mean, latent_logvar, vae_prev_obs, vae_next_obs,
                                                             vae_actions, vae_rewards, vae_tasks,
                                                             trajectory_lens, len_encoder)
        elif self.args.split_batches_by_elbo:
            losses = self.compute_loss_split_batches_by_elbo(latent_mean, latent_logvar, vae_prev_obs, vae_next_obs,
                                                             vae_actions, vae_rewards, vae_tasks,
                                                             trajectory_lens)
        else:
            losses = self.compute_loss(latent_mean, latent_logvar, vae_prev_obs, vae_next_obs, vae_actions,
                                       vae_rewards, vae_tasks, trajectory_lens, y, z, mu, var, logits, prob, output)
        rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, gauss_loss, cat_loss, occ_loss, ext_loss = losses

        # take average (this is the expectation over p(M))
        loss = (self.args.rew_loss_coeff * rew_reconstruction_loss +
                self.args.state_loss_coeff * state_reconstruction_loss +
                self.args.task_loss_coeff * task_reconstruction_loss +
                self.args.gauss_loss_coeff * gauss_loss +
                self.args.cat_loss_coeff * cat_loss +
                self.args.occ_loss_coeff * occ_loss +
                self.args.ext_loss_coeff * ext_loss).mean()

        # make sure we can compute gradients
        if self.args.decode_reward:
            assert rew_reconstruction_loss.requires_grad
        if self.args.decode_state:
            assert state_reconstruction_loss.requires_grad
        if self.args.decode_task:
            assert task_reconstruction_loss.requires_grad
        assert gauss_loss.requires_grad
        assert cat_loss.requires_grad
        if self.args.occ_loss_coeff != 0:
            assert occ_loss.requires_grad
        assert ext_loss.requires_grad

        # overall loss
        elbo_loss = loss.mean()

        if update:
            self.optimiser_vae.zero_grad()
            elbo_loss.backward()
            # clip gradients
            if self.args.encoder_max_grad_norm is not None:
                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.args.encoder_max_grad_norm)
            if self.args.decoder_max_grad_norm is not None:
                if self.args.decode_reward:
                    nn.utils.clip_grad_norm_(self.reward_decoder.parameters(), self.args.decoder_max_grad_norm)
                if self.args.decode_state:
                    nn.utils.clip_grad_norm_(self.state_decoder.parameters(), self.args.decoder_max_grad_norm)
                if self.args.decode_task:
                    nn.utils.clip_grad_norm_(self.task_decoder.parameters(), self.args.decoder_max_grad_norm)
            # update
            self.optimiser_vae.step()

        self.log(elbo_loss, rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, gauss_loss, cat_loss, occ_loss, ext_loss, pretrain_index)


        return elbo_loss

    def log(self, elbo_loss, rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, gauss_loss, cat_loss, occ_loss, ext_loss, pretrain_index=None):

        if pretrain_index is None:
            curr_iter_idx = self.get_iter_idx()
        else:
            curr_iter_idx = - self.args.pretrain_len * self.args.num_vae_updates_per_pretrain + pretrain_index

        if curr_iter_idx % self.args.log_interval == 0:

            if self.args.decode_reward:
                self.logger.add('vae_losses/reward_reconstr_err', rew_reconstruction_loss.mean(), curr_iter_idx)
            if self.args.decode_state:
                self.logger.add('vae_losses/state_reconstr_err', state_reconstruction_loss.mean(), curr_iter_idx)
            if self.args.decode_task:
                self.logger.add('vae_losses/task_reconstr_err', task_reconstruction_loss.mean(), curr_iter_idx)

            self.logger.add('vae_losses/gauss_loss', gauss_loss.mean(), curr_iter_idx)
            self.logger.add('vae_losses/cat_loss', cat_loss.mean(), curr_iter_idx)
            if self.args.occ_loss_coeff != 0:
                self.logger.add('vae_losses/occ_loss', occ_loss.mean(), curr_iter_idx)
            self.logger.add('vae_losses/ext_loss', ext_loss.mean(), curr_iter_idx)
            self.logger.add('vae_losses/sum', elbo_loss, curr_iter_idx)

    ## OFFLINE skill-aware decoder distillation        
    def pretrain_with_skill_data(
        self,
        skill_buffer,
        epochs: int = 5000,
        mixup_alpha: float = 0.4,
    ):
        """
        Run once by `pretrainer_SDVT.py`.
        Updates  (state|reward|task) decoders plus  μ_y, σ_y².
        """
        loader = DataLoader(
            skill_buffer,
            batch_size=self.args.vae_batch_num_trajs,
            shuffle=True,
            drop_last=True,
        )

        opt = torch.optim.Adam(
            [
                *(self.state_decoder.parameters() if self.state_decoder else []),
                *(self.reward_decoder.parameters() if self.reward_decoder else []),
                *(self.task_decoder.parameters()  if self.task_decoder  else []),
                self.mu_y,
                self.sigma_y,
            ],
            lr=self.args.lr_vae,
        )

        beta = 1.0
        global_step = 0

        for epoch in range(epochs):
            for prev_s, act, next_s, r_int, skill, _, _ in loader:
                global_step += 1
                prev_s, act, next_s, r_int, skill = [
                    t.to(get_device()) for t in (prev_s, act, next_s, r_int, skill)
                ]

                # MixUp on categorical skill
                if mixup_alpha > 0:
                    lam = torch.distributions.Beta(
                        mixup_alpha, mixup_alpha
                    ).sample().to(get_device())
                    perm = torch.randperm(prev_s.size(0), device=get_device())
                    skill = lam * skill + (1 - lam) * skill[perm]
                    r_int = lam * r_int + (1 - lam) * r_int[perm]

                # --- Sample new ε ~ N(μ_y, σ_y) per skill index ---
                idx = skill.argmax(dim=1)  # which mixture‐component each example belongs to
                mu    = self.mu_y[idx].to(get_device())      # (batch_size, latent_dim)
                var   = self.sigma_y[idx].to(get_device())   # (batch_size, latent_dim)
                eps   = mu + torch.randn_like(mu).to(get_device()) * var.sqrt()

                # recon losses
                nll_state = torch.tensor(0.0, device=get_device())
                nll_rew   = torch.tensor(0.0, device=get_device())

                if self.state_decoder is not None:
                    state_pred = self.state_decoder(eps, prev_s, act)
                    nll_state = (state_pred - next_s).pow(2).mean()

                if self.reward_decoder is not None:
                    rew_pred, _, _, _ = self.reward_decoder(
                        eps, next_s, prev_s, act.float(), skill
                    )
                    nll_rew = (rew_pred - r_int).pow(2).mean()

                kl = 0.5 * (var + mu.pow(2) - 1 - var.log()).mean()
                # --- Total VAE loss (for this batch) ---
                total_loss = (nll_state
                              + nll_rew
                              + beta * kl).mean()

                opt.zero_grad()
                total_loss.backward()
                opt.step()

            # TensorBoard scalars
            if self.logger is not None:
                self.logger.add("vae_losses/recon_state", nll_state.item(), global_step)
                self.logger.add("vae_losses/recon_reward", nll_rew.item(), global_step)
                self.logger.add("vae_losses/kl", kl.item(), global_step)
                self.logger.add("vae_losses/total", total_loss.item(), global_step)

                # Also log parameters of the skill priors
                mu_cpu    = self.mu_y.detach().cpu().numpy()    # shape: (K, D)
                sigma_cpu = self.sigma_y.detach().cpu().numpy() # shape: (K, D)

                # 2. For convenience, compute per‐component mean/variance across latent‐dimensions:
                #    - mu_mean[k] = mean over dimension for skill k
                #    - mu_var[k]  = variance across dimension for skill k
                mu_mean = mu_cpu.mean(axis=1)
                mu_var  = mu_cpu.var(axis=1)
                sigma_mean = sigma_cpu.mean(axis=1)
                sigma_var  = sigma_cpu.var(axis=1)

                # 3. Log means/vars as scalars
                for k in range(self.args.vae_mixture_num):
                    self.logger.add(f'prior/mu_mean_skill_{k}', mu_mean[k], epoch)
                    self.logger.add(f'prior/mu_var_skill_{k}',  mu_var[k],  epoch)
                    self.logger.add(f'prior/sigma_mean_skill_{k}', sigma_mean[k], epoch)
                    self.logger.add(f'prior/sigma_var_skill_{k}',  sigma_var[k],  epoch)


            beta = max(0.0, beta - 1.0 / epochs)  # linear KL anneal

            # -----------------------------
            #  Log to console, then plot
            # -----------------------------
            current_time = time.time()
            current_time_str = time.strftime("%H:%M:%S")
            if hasattr(self, 'last_log_time'):
                time_since_last = int(current_time - self.last_log_time)
                time_msg = f" (+{time_since_last}s)"
            else:
                time_msg = ""
            self.last_log_time = current_time

            
            print(f"[VAE] {current_time_str} | epoch {epoch}/{epochs}{time_msg}: "
                    f"last_total_loss={total_loss.item():.4f}, "
                    f"last_nll_state={nll_state.item():.4f}, "
                    f"last_nll_rew={nll_rew.item():.4f}, last_kl={kl.item():.4f}"
                    )

            if global_step % self.args.vis_interval == 0:
                fig, axs = plt.subplots(2, 1, figsize=(5, 8), tight_layout=True)
                axs[0].bar(range(len(mu_mean)), mu_mean)
                axs[0].set_title("skill‐priors: μ_y mean per component")
                axs[0].set_xlabel("skill index k")
                axs[0].set_ylabel("μ_mean")
                axs[1].bar(range(len(sigma_mean)), sigma_mean, color='C1')
                axs[1].set_title("skill‐priors: σ_y mean per component")
                axs[1].set_xlabel("skill index k")
                axs[1].set_ylabel("σ_mean")

                plot_path = Path(self.logger.full_output_folder) / f"prior_trends_epoch_{epoch}.png"
                fig.savefig(str(plot_path))
                plt.close(fig) 

    # Called by metalearner to load priors from disk
    def set_skill_prior(self, prior: dict):
        """
        Overwrite the skill-prior parameters with values from a distillation checkpoint.

        `prior["mu"]` is copied in place into `self.mu_y` and `prior["sigma"]` into
        `self.sigma_y`, each after being moved to the device given by `get_device()`.
        """
        target_device = get_device()

        mu_values = prior["mu"].to(target_device)
        self.mu_y.data.copy_(mu_values)

        sigma_values = prior["sigma"].to(target_device)
        self.sigma_y.data.copy_(sigma_values)

    def load_style_net(self, path: str):
        if self.style_net is not None:
            self.style_net.load_state_dict(torch.load(path, map_location=get_device()))
