import os
import time

import gym
import numpy as np
import torch

from algorithms.online_storage import OnlineStorage
from algorithms.ppo import PPO
from environments.parallel_envs import make_vec_envs
from models.policy import Policy
from utils import evaluation as utl_eval
from utils import helpers as utl
from utils.tb_logger import TBLogger
from vae import VaribadVAE
from vae_mixture import VaribadVAEMixture
from vae_dme import VaribadVAEDME
from models.policy_encoder import PolicyEncoder
import torch.nn.functional as F

import metaworld
import random
import csv

torch.autograd.set_detect_anomaly(True)
from utils.helpers import get_device
from pathlib import Path


class MetaLearnerML10DME:
    """
    Meta-Learner class with the main training loop for DME.
    """
    def __init__(self, args):
        """
        Build every component needed for DME meta-training.

        Seeds the run, creates the logger and the evaluation CSV, builds the
        vectorised environments, the VAE (optionally warm-started from a
        skill-distillation checkpoint), the optional separate policy encoder,
        the policy storage and the PPO trainer.  When ``args.load_dir`` is set,
        the model/optimiser/normalisation state saved for ``args.load_iter`` is
        restored on top of that.

        :param args: experiment configuration namespace.  It is mutated here:
                     ``max_trajectory_len``, ``state_dim``, ``task_dim``,
                     ``belief_dim``, ``num_states``, ``action_space`` and
                     ``action_dim`` are written, and ``precollect_len`` is
                     shifted by the already-consumed frames when resuming.
        """
        self.args = args
        assert self.args.vae_mixture_num > 1, 'DME requires vae_mixture_num > 1'
        self.virtual_ratio = self.args.virtual_ratio
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # calculate number of updates and keep count of frames/iterations
        self.num_updates = int(args.num_frames) // args.policy_num_steps // args.num_processes

        if self.args.load_dir is None:
            self.frames = 0
            self.iter_idx = -1
        else:
            # resuming: continue the frame / iteration counters from the checkpoint
            self.frames = (int(self.args.load_iter) + 1) * self.args.policy_num_steps * self.args.num_processes
            self.args.precollect_len = self.args.precollect_len + self.frames
            self.iter_idx = int(self.args.load_iter)

        self.recent_train_success = np.zeros(10)
        self.task_count = np.zeros(10)

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # evaluation CSV: 'iter', 'frames', then R0..R14, S0..S14, SF0..SF14, RF0..RF14
        num_eval_columns = 15
        header = ['iter', 'frames'] + [
            f'{prefix}{column}'
            for prefix in ('R', 'S', 'SF', 'RF')
            for column in range(num_eval_columns)
        ]
        # newline='' lets the csv module control line endings itself
        # (otherwise extra blank rows appear on platforms that translate '\n')
        eval_csv_path = os.path.join(self.logger.full_output_folder, 'log_eval.csv')
        with open(eval_csv_path, 'w', encoding='UTF8', newline='') as f:
            csv.writer(f).writerow(header)

        self.train_tasks = None

        # initialise environments
        self.envs = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=args.num_processes,
                                  gamma=args.policy_gamma, device=get_device(),
                                  episodes_per_task=self.args.max_rollouts_per_task,
                                  normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                                  tasks=None
                                  )

        # calculate what the maximum length of the trajectories is
        self.args.max_trajectory_len = self.envs._max_episode_steps
        self.args.max_trajectory_len *= self.args.max_rollouts_per_task

        # get policy input dimensions
        self.args.state_dim = self.envs.observation_space.shape[0]
        self.args.task_dim = self.envs.task_dim

        self.args.belief_dim = self.envs.belief_dim
        self.args.num_states = self.envs.num_states
        # get policy output (action) dimensions
        self.args.action_space = self.envs.action_space

        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]

        # initialise GMVAE
        self.vae = VaribadVAEDME(self.args, self.logger, lambda: self.iter_idx)

        # optional warm start from a skill-distillation run; only attempted when
        # every pretrain_* argument needed to build the checkpoint path is present
        if (
            getattr(self.args, 'pretrainer', None)
            and getattr(self.args, 'pretrain_env_name', None)
            and getattr(self.args, 'pretrain_exp_name', None)
            and getattr(self.args, 'pretrain_seed', None) is not None
            and getattr(self.args, 'pretrain_frames', None) is not None
        ):
            # build checkpoint path from the pretrain_* args
            ckpt = (
                Path("logs") / "skill_distill_models" /
                self.args.pretrainer /
                self.args.pretrain_env_name /
                f"{self.args.pretrain_exp_name}_{self.args.pretrain_seed}" /
                str(self.args.pretrain_frames)
            ).expanduser().resolve()

            # 1.  decoders
            self.vae.state_decoder.load_state_dict(
                torch.load(ckpt / "vae_state_decoder.pt", map_location=get_device())
            )
            self.vae.reward_decoder.load_state_dict(
                torch.load(ckpt / "vae_reward_decoder.pt", map_location=get_device())
            )

            # 2.  per-skill Gaussian priors  (μ_y, σ²_y)
            prior = torch.load(ckpt / "vae_prior.pt", map_location=get_device())
            self.vae.set_skill_prior(prior)

            # 3.  optional style encoder
            style_ckpt = ckpt / "vae_style_encoder.pt"
            if style_ckpt.exists():
                self.vae.load_style_net(style_ckpt)

            print(f"[DME] loaded skill-aware decoder and priors from {ckpt}")

        if self.args.policy_separate_gru:
            self.encoder_pol = PolicyEncoder(self.args, self.logger, lambda: self.iter_idx)
        else:
            self.encoder_pol = None

        self.policy_storage = self.initialise_policy_storage()
        self.policy = self.initialise_policy()

        if self.args.load_dir is not None:
            print('loading pretrained model from ', self.args.load_dir)
            load_iter = self.args.load_iter
            device = get_device()
            # one canonical models folder, valid with or without a trailing
            # separator on load_dir
            models_dir = os.path.join(self.args.load_dir, 'models')

            def load_ckpt(stem):
                """Load ``<load_dir>/models/<stem><load_iter>.pt`` onto the active device."""
                return torch.load(os.path.join(models_dir, '{}{}.pt'.format(stem, load_iter)),
                                  map_location=device)

            # the model checkpoints are whole pickled modules; only their weights are copied over
            self.policy.actor_critic.load_state_dict(load_ckpt('policy').state_dict())
            self.policy.actor_critic.train()
            print('policy loaded')

            self.vae.encoder.load_state_dict(load_ckpt('encoder').state_dict())
            self.vae.encoder.train()
            print('vae.encoder loaded')
            if self.encoder_pol is not None:
                # NOTE(review): this rebinds self.encoder_pol after initialise_policy() has
                # already handed the previous object's optimiser_vae to PPO - confirm PPO
                # ends up stepping the optimiser of the encoder that is actually used.
                self.encoder_pol = load_ckpt('encoder_pol')
                self.encoder_pol.train()
                self.encoder_pol.optimiser_vae.load_state_dict(load_ckpt('encoder_pol_optimiser_pol'))
                print('encoder_pol loaded')
            if self.vae.state_decoder is not None:
                self.vae.state_decoder.load_state_dict(load_ckpt('state_decoder').state_dict())
                self.vae.state_decoder.train()
                print('vae.state_decoder loaded')
            if self.vae.reward_decoder is not None:
                self.vae.reward_decoder.load_state_dict(load_ckpt('reward_decoder').state_dict())
                self.vae.reward_decoder.train()
                print('vae.reward_decoder loaded')
            if self.vae.task_decoder is not None:
                self.vae.task_decoder.load_state_dict(load_ckpt('task_decoder').state_dict())
                self.vae.task_decoder.train()
                print('vae.task_decoder loaded')
            self.vae.optimiser_vae.load_state_dict(load_ckpt('optimiser_vae'))
            self.policy.optimiser.load_state_dict(load_ckpt('optimiser_pol'))

            # The original built this folder as load_dir + 'models/' (no separator), which
            # disagreed with the '/models/' used for every other file.  The trailing
            # separator is kept so load_obj works whether it joins or concatenates paths.
            rms_dir = os.path.join(models_dir, '')
            if self.args.norm_rew_for_policy:
                rew_rms = utl.load_obj(rms_dir, 'env_rew_rms{}'.format(load_iter))
                self.envs.venv.ret_rms = rew_rms
            if self.args.norm_state_for_policy:
                obs_rms = utl.load_obj(rms_dir, 'pol_state_rms{}'.format(load_iter))
                self.policy.actor_critic.state_rms = obs_rms

    def initialise_policy_storage(self):
        """Create the on-policy rollout buffer, sized from the run configuration."""
        cfg = self.args
        storage_kwargs = dict(
            args=cfg,
            num_steps=cfg.policy_num_steps,
            num_processes=cfg.num_processes,
            state_dim=cfg.state_dim,
            latent_dim=cfg.latent_dim,
            belief_dim=cfg.belief_dim,
            task_dim=cfg.task_dim,
            # one probability entry per mixture component
            prob_dim=cfg.vae_mixture_num,
            action_space=cfg.action_space,
            hidden_size=cfg.encoder_gru_hidden_size,
            normalise_rewards=cfg.norm_rew_for_policy,
        )
        return OnlineStorage(**storage_kwargs)

    def initialise_policy(self):
        """
        Build the actor-critic network and wrap it in a PPO trainer.

        The network is moved to ``get_device()``.  PPO is additionally handed the
        VAE optimiser and, when a separate policy encoder exists, that encoder's
        optimiser (otherwise ``None``).

        :return: the configured ``PPO`` instance.
        """
        cfg = self.args

        # which inputs the policy conditions on, and how wide each one is
        input_spec = dict(
            pass_state_to_policy=cfg.pass_state_to_policy,
            pass_latent_to_policy=cfg.pass_latent_to_policy,
            pass_belief_to_policy=cfg.pass_belief_to_policy,
            pass_task_to_policy=cfg.pass_task_to_policy,
            pass_prob_to_policy=cfg.pass_prob_to_policy,
            dim_state=cfg.state_dim,
            dim_latent=cfg.latent_dim * 2,
            dim_belief=cfg.belief_dim,
            dim_task=cfg.task_dim,
        )
        # network architecture
        body_spec = dict(
            hidden_layers=cfg.policy_layers,
            activation_function=cfg.policy_activation_function,
            policy_initialisation=cfg.policy_initialisation,
        )
        # action head
        head_spec = dict(
            action_space=self.envs.action_space,
            init_std=cfg.policy_init_std,
            min_std=cfg.policy_min_std,
            max_std=cfg.policy_max_std,
        )
        # w (style latent) support; w is given the same width as the latent
        style_spec = dict(
            pass_w_to_policy=getattr(cfg, 'pass_w_to_policy', False),
            dim_w=cfg.latent_dim,
        )
        actor_critic = Policy(args=cfg, **input_spec, **body_spec, **head_spec, **style_spec)
        actor_critic = actor_critic.to(get_device())

        if self.encoder_pol is not None:
            encoder_pol_optimiser = self.encoder_pol.optimiser_vae
        else:
            encoder_pol_optimiser = None

        # initialise policy trainer
        trainer = PPO(
            cfg,
            actor_critic,
            cfg.policy_value_loss_coef,
            cfg.policy_entropy_coef,
            policy_optimiser=cfg.policy_optimiser,
            policy_anneal_lr=cfg.policy_anneal_lr,
            train_steps=self.num_updates,
            lr=cfg.lr_policy,
            eps=cfg.policy_eps,
            ppo_epoch=cfg.ppo_num_epochs,
            num_mini_batch=cfg.ppo_num_minibatch,
            use_huber_loss=cfg.ppo_use_huberloss,
            use_clipped_value_loss=cfg.ppo_use_clipped_value_loss,
            clip_param=cfg.ppo_clip_param,
            optimiser_vae=self.vae.optimiser_vae,
            optimiser_encoder_pol=encoder_pol_optimiser,
            grad_correction=cfg.grad_correction,
        )
        return trainer

    def sample_virtual_task_dme(self, num_virtual_skills):
        """
        Draw virtual-task variables from the encoder and keep the two intercepts.

        ``self.vae.encoder.sample_virtual_task`` yields five values; only the
        first two are handed back, the remaining three are discarded.

        :param num_virtual_skills: forwarded unchanged to ``sample_virtual_task``.
        :return: ``(w_intercept, y_intercept)``.  Presumably shaped
                 ``(num_virtual_skills, latent_dim)`` and
                 ``(num_virtual_skills, vae_mixture_num)`` - confirm against the encoder.
        """
        sampler = self.vae.encoder.sample_virtual_task
        w_intercept, y_intercept, _z, _mu_z, _logvar_z = sampler(num_virtual_skills)
        return w_intercept, y_intercept

    def train(self):
        """
        Main meta-training loop.

        Every iteration
          1. decides at random (probability ``self.virtual_ratio``) whether the
             rollout is "virtual",
          2. re-encodes the running trajectories with the current VAE,
          3. rolls the policy out for ``args.policy_num_steps`` steps,
          4. once ``self.frames`` has reached ``args.precollect_len``, runs either
             VAE-only pre-training or the joint policy/VAE update followed by logging,
          5. clears the policy storage and advances the counters.

        The loop stops when ``self.iter_idx`` reaches ``self.num_updates``; the
        environments are closed afterwards.
        """
        start_time = time.time()
        if getattr(self.args, 'debug_time', False):
            print(f"[DME DEBUG] Starting train() method")

        # reset environments
        prev_state, belief, task = utl.reset_env(self.envs, self.args)
        
        # Sample initial virtual task using GMVAE
        w_intercept, y_intercept = self.sample_virtual_task_dme(self.args.num_processes)

        self.policy_storage.prev_state[0].copy_(prev_state)
        
        # log once before training
        if not getattr(self.args, 'debug_no_log', False):
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Initial pre-log setup time took: {time.time() - start_time:.4f}s")
            initial_log_start_time = time.time()
            with torch.no_grad():
                self.log(None, None, start_time)
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Initial log call took: {time.time() - initial_log_start_time:.4f}s")
        else:
            print("[DEBUG] Skipping initial log before training...")
            
        self.iter_idx += 1 # number of interactions with the real environment
        self.virtual_iter_idx = self.iter_idx #total interactions including the virtual
        
        while self.iter_idx < self.num_updates:
            iter_start_time = time.time()
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Starting iteration {self.iter_idx}")
            
            # the real/virtual decision is made once per iteration and holds for the whole rollout
            if random.random() < self.virtual_ratio: #CAUTION: this code is valid only when policy_num_steps is multiple of 5000
                virtual = True
                if getattr(self.args, 'debug_time', False):
                    print(f"[DME DEBUG] Using virtual environment this iteration")
            else:
                virtual = False
                if getattr(self.args, 'debug_time', False):
                    print(f"[DME DEBUG] Using real environment this iteration")
                
            # First, re-compute the hidden states given the current rollouts (since the VAE might've changed)
            with torch.no_grad():
                latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w = self.encode_running_trajectory(virtual) #CAUTION: this code is valid only when policy_num_steps is multiple of 5000
                # strictly this is not correct, since the VAE is changing during virtual meta-epi as well,ok only when meta-episode ends every iter
                # latent_sample_v conditions the reward decoder during a virtual rollout
                latent_sample_v = latent_sample.clone().detach()
                w_v = w.clone().detach()

                if self.args.policy_separate_gru:
                    # NOTE(review): encode_running_trajectory_pol as defined in this class returns
                    # four values, but two are unpacked here, and `virtual` is not forwarded -
                    # confirm which side is intended.
                    latent_pol, hidden_state_pol = self.encode_running_trajectory_pol(encoder = self.encoder_pol.encoder)
            
            # add this initial latent state to the policy storage
            assert len(self.policy_storage.latent_mean) == 0  # make sure we emptied buffers
            self.policy_storage.hidden_states[0].copy_(hidden_state)
            self.policy_storage.latent_samples.append(latent_sample.clone())
            self.policy_storage.latent_mean.append(latent_mean.clone())
            self.policy_storage.latent_logvar.append(latent_logvar.clone())
            self.policy_storage.prob.append(prob.clone())
            # w is only stored when the storage exposes a (non-None) w buffer
            if hasattr(self.policy_storage, 'w') and self.policy_storage.w is not None:
                self.policy_storage.w.append(w.clone())
            if self.args.policy_separate_gru:
                self.policy_storage.latent_pol.append(latent_pol.clone())
                self.policy_storage.hidden_states_pol[0].copy_(hidden_state_pol)
            
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] pre-rollout setup time took: {time.time() - iter_start_time:.4f}s")

            # rollout policies for a few steps
            rollout_start_time = time.time()
            for step in range(self.args.policy_num_steps):
                # sample actions from policy
                with torch.no_grad():
                    value, action = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        state=prev_state,
                        belief=belief,
                        task=task,
                        prob=prob,
                        latent_pol = latent_pol if self.args.policy_separate_gru else None,
                        deterministic=False,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                        w=w,
                    )

                # take step in the environment (the real envs are stepped in both real and virtual mode)
                [next_state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(self.envs, action, self.args)

                # virtual rollout: overwrite the environment rewards with the reward decoder's prediction
                if virtual: #use virtual environment
                    with torch.no_grad():
                        rew_raw_pred = self.vae.reward_decoder(latent_sample_v.detach(), next_state, prev_state, action)
                    # predicted reward is clamped to [0, 10]
                    rew_raw = torch.clamp(rew_raw_pred.clone().detach(), min=0.0, max=10.0)
                    rew_raw_np = rew_raw.cpu().numpy()
                    # presumably applies the env wrapper's reward normalisation to the predicted reward - TODO confirm
                    rew_normalised = self.envs.venv._rewfilt2(rew_raw_np)

                    # keep the first column only and restore a trailing dimension of size 1
                    rew_normalised = torch.from_numpy(rew_normalised)[:,0].unsqueeze(dim=1).float().to(get_device())
                
                done = torch.from_numpy(np.array(done, dtype=int)).to(get_device()).float().view((-1, 1))
                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(get_device())
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]).to(get_device())

                with torch.no_grad():
                    # compute next embedding (for next loop and/or value prediction bootstrap)
                    if virtual:
                        # For virtual steps, update with virtual task intercept.
                        # Of these outputs only latent_sample_v is read again in this method
                        # (by the reward decoder on the next step).
                        latent_sample_v, _, _, hidden_state_v, \
                        y_v, prob_v, w_v = utl.update_encoding_dme(
                            encoder=self.vae.encoder,
                            next_obs=next_state,
                            action=action,
                            reward=rew_raw,
                            done=done,
                            hidden_state=hidden_state,
                            w_intercept=w_intercept,
                            y_intercept=y_intercept)
                    
                    # regular (no intercept) encoding update; this is what the policy and its storage see
                    latent_sample, latent_mean, latent_logvar, hidden_state, \
                    y, prob, w = utl.update_encoding_dme(
                        encoder=self.vae.encoder,
                        next_obs=next_state,
                        action=action,
                        reward=rew_raw,
                        done=done,
                        hidden_state=hidden_state,
                        w_intercept=None,
                        y_intercept=None)
                        
                    if self.args.policy_separate_gru:
                        latent_pol, hidden_state_pol = utl.update_encoding_pol(
                            encoder=self.encoder_pol.encoder,
                            next_obs=next_state,
                            action=action,
                            reward=rew_raw,
                            done=done,
                            hidden_state=hidden_state_pol,
                        )

                # before resetting, update the embedding and add to vae buffer
                # (last state might include useful task info)
                # virtual and real transitions go to separate VAE buffers
                if not (self.args.disable_decoder and self.args.disable_kl_term):
                    if virtual:
                        self.vae.rollout_storage_virtual.insert(prev_state.clone(),
                                                        action.detach().clone(),
                                                        next_state.clone(),
                                                        rew_raw.clone(),
                                                        done.clone(),
                                                        task.clone() if task is not None else None)
                    else:
                        self.vae.rollout_storage.insert(prev_state.clone(),
                                                        action.detach().clone(),
                                                        next_state.clone(),
                                                        rew_raw.clone(),
                                                        done.clone(),
                                                        task.clone() if task is not None else None)
                    
                # add the obs before reset to the policy storage
                self.policy_storage.next_state[step] = next_state.clone()

                # reset environments that are done
                done_indices = np.argwhere(done.cpu().flatten()).flatten()

                if len(done_indices) > 0:
                    # NOTE(review): the task ids of all processes are counted here, not only
                    # those in done_indices - confirm this is the intended bookkeeping.
                    task_indicies = np.array(self.envs.get_task())[:,0]
                    for i in range(10):
                        self.task_count[i] += np.count_nonzero(task_indicies == i)
                    #TODO1: for virtual envs, we need to store the initial states in a buffer
                    next_state, belief, task = utl.reset_env(self.envs, self.args, indices=done_indices, state=next_state)

                    if virtual:
                        # resample virtual task with new distribution (DME version);
                        # a fresh sample is drawn for every process, not just the finished ones
                        w_intercept, y_intercept = self.sample_virtual_task_dme(self.args.num_processes)
                    
                # add experience to policy buffer
                self.policy_storage.insert(
                    state=next_state,
                    belief=belief,
                    task=task,
                    actions=action,
                    rewards_raw=rew_raw,
                    rewards_normalised=rew_normalised,
                    value_preds=value,
                    masks=masks_done,
                    bad_masks=bad_masks,
                    done=done,
                    hidden_states=hidden_state.squeeze(0),
                    latent_sample=latent_sample,
                    latent_mean=latent_mean,
                    latent_logvar=latent_logvar,
                    y=y,
                    prob=prob,
                    w=w,
                    latent_pol=latent_pol if self.args.policy_separate_gru else None,
                    hidden_states_pol=hidden_state_pol.squeeze(0) if self.args.policy_separate_gru else None
                )
                prev_state = next_state
                # frames are counted for virtual rollouts as well
                self.frames += self.args.num_processes
            
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Full rollout took: {time.time() - rollout_start_time:.4f}s")

            # --- UPDATE ---
            # no update of any kind until precollect_len frames have been gathered
            if self.args.precollect_len <= self.frames:
                update_start_time = time.time()

                # check if we are pre-training the VAE
                if self.args.pretrain_len > self.iter_idx:
                    for p in range(self.args.num_vae_updates_per_pretrain):
                        self.vae.compute_vae_loss(update=True,
                                                  pretrain_index=self.iter_idx * self.args.num_vae_updates_per_pretrain + p)
                # otherwise do the normal update (policy + vae)
                else:
                    train_stats = self.update(state=prev_state,
                                              belief=belief,
                                              task=task,
                                              prob=prob,
                                              latent_pol=latent_pol if self.args.policy_separate_gru else None,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              w=w)

                    # log (only on iterations that ran the normal update)
                    run_stats = [action, self.policy_storage.action_log_probs, value]
                    
                    if getattr(self.args, 'debug_time', False):
                        print(f"[DME DEBUG] Update section took: {time.time() - update_start_time:.4f}s")
                        
                    log_start_time = time.time()
                    with torch.no_grad():
                        self.log(run_stats, train_stats, start_time)
                    
                    if getattr(self.args, 'debug_time', False):
                        print(f"[DME DEBUG] Logging took: {time.time() - log_start_time:.4f}s")
                
            # clean up after update (runs every iteration, whether or not an update happened)
            cleanup_start_time = time.time()
            self.policy_storage.after_update()
            
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Cleanup took: {time.time() - cleanup_start_time:.4f}s")

            self.virtual_iter_idx += 1
            self.iter_idx += 1
            # grow the virtual-rollout probability in proportion to the frames collected this iteration
            self.virtual_ratio += self.args.virtual_ratio_increment*(self.args.num_processes*self.args.policy_num_steps)/1e8
            
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Full iteration {self.iter_idx-1} took: {time.time() - iter_start_time:.4f}s")
                print(f"[DME DEBUG] --------------")
        self.envs.close()

    def encode_running_trajectory(self, virtual = False):
        """
        Re-encode the trajectories that the policy has taken so far.

        The running batch is read from the virtual rollout storage when
        ``virtual`` is true and from the real one otherwise, then passed through
        the VAE encoder from scratch (``hidden_state=None``, ``return_prior=True``).
        For every process the entry at that process' current length is picked
        out of each encoder output.

        :param virtual: use ``self.vae.rollout_storage_virtual`` instead of
                        ``self.vae.rollout_storage``.
        :return: the 7-tuple ``(latent_sample, latent_mean, latent_logvar,
                 hidden_state, y, prob, w)``; each element is stacked over
                 processes and moved to ``get_device()``.
        """
        storage = self.vae.rollout_storage_virtual if virtual else self.vae.rollout_storage
        # for each process, get the current batch (zero-padded obs/act/rew + length indicators)
        _, next_state, action, reward, lens = storage.get_running_batch()

        # the encoder is expected to return exactly these seven sequences
        latent_sample, latent_mean, latent_logvar, all_hidden_states, \
            y, prob, w = self.vae.encoder(actions=action,
                                          states=next_state,
                                          rewards=reward,
                                          hidden_state=None,
                                          return_prior=True)

        device = get_device()

        def at_current_step(sequence):
            """Pick, for each process, the entry at that process' own length (batches are zero-padded)."""
            # sequences are indexed [step][process]
            return torch.stack([sequence[length][proc] for proc, length in enumerate(lens)]).to(device)

        return (
            at_current_step(latent_sample),
            at_current_step(latent_mean),
            at_current_step(latent_logvar),
            at_current_step(all_hidden_states),
            at_current_step(y),
            at_current_step(prob),
            at_current_step(w),
        )

    def encode_running_trajectory_pol(self, encoder, virtual = False):
        """
        Re-encode the running trajectories with a separate (policy) encoder.

        The running batch comes from the virtual rollout storage when ``virtual``
        is true, otherwise from the real one.  ``encoder`` is called from scratch
        (``hidden_state=None``, ``return_prior=False``) and must return four
        values; the first three must be rank-3 and are indexed as
        ``[process, lens - 1, :]`` to keep each process' last step.

        NOTE(review): the call site in ``train`` unpacks two values
        (``latent_pol, hidden_state_pol``) while four are returned here - confirm
        which contract is intended.

        :param encoder: callable accepting ``actions``, ``states``, ``rewards``,
                        ``hidden_state`` and ``return_prior`` keyword arguments.
        :param virtual: read from ``self.vae.rollout_storage_virtual``.
        :return: ``(latent_sample, latent_mean, latent_logvar, output)`` where
                 ``output`` is the encoder's fourth return value, passed through untouched.
        """
        storage = self.vae.rollout_storage_virtual if virtual else self.vae.rollout_storage
        # zero-padded obs/act/rew per process, plus the current length of each trajectory
        _, next_state, action, reward, lens = storage.get_running_batch()

        encoded = encoder(
            actions=action,
            states=next_state,
            rewards=reward,
            hidden_state=None,
            return_prior=False,
        )
        latent_sample, latent_mean, latent_logvar, output = encoded

        # each latent tensor must be [num_processes, num_steps, latent_dim]
        for latent in (latent_sample, latent_mean, latent_logvar):
            assert len(latent.shape) == 3

        # take the final step of each process -> [num_processes, latent_dim]
        processes = range(len(lens))
        final_step = lens - 1
        latent_sample = latent_sample[processes, final_step, :]
        latent_mean = latent_mean[processes, final_step, :]
        latent_logvar = latent_logvar[processes, final_step, :]

        return latent_sample, latent_mean, latent_logvar, output

    def get_value(self, state, belief, task, prob, latent_pol, latent_sample, latent_mean, latent_logvar, w=None):
        """
        Return the critic's value estimate for the given inputs, detached from the graph.

        The latent passed to the critic is built by ``utl.get_latent_for_policy``
        from the sample / mean / log-variance according to ``self.args``.
        """
        policy_latent = utl.get_latent_for_policy(
            self.args,
            latent_sample=latent_sample,
            latent_mean=latent_mean,
            latent_logvar=latent_logvar,
        )
        critic = self.policy.actor_critic
        value = critic.get_value(
            state=state,
            belief=belief,
            task=task,
            latent=policy_latent,
            prob=prob,
            latent_pol=latent_pol,
            w=w,
        )
        return value.detach()

    def update(self, state, belief, task, prob, latent_pol, latent_sample, latent_mean, latent_logvar, w=None):
        """
        Run one training update and return the policy training statistics.

        While ``self.iter_idx`` is below ``args.pretrain_len`` only the VAE is
        trained; at iteration 0 outside pre-training nothing is trained.  In
        both cases the placeholder ``(0, 0, 0, 0)`` is returned.  Otherwise the
        value of the final state is bootstrapped, returns are computed, and
        ``self.policy.update`` is run (it is also given ``compute_vae_loss``).

        :return: whatever ``self.policy.update`` returns, or ``(0, 0, 0, 0)``.
        """
        still_pretraining = self.iter_idx < self.args.pretrain_len
        if still_pretraining or self.iter_idx <= 0:
            # no policy step this iteration
            if still_pretraining:
                # pre-train the VAE
                self.vae.compute_vae_loss(update=True)
            return 0, 0, 0, 0

        cfg = self.args

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = self.get_value(
                state=state,
                belief=belief,
                task=task,
                prob=prob,
                latent_pol=latent_pol,
                latent_sample=latent_sample,
                latent_mean=latent_mean,
                latent_logvar=latent_logvar,
                w=w,
            )

        # compute returns for current rollouts
        self.policy_storage.compute_returns(
            next_value,
            cfg.policy_use_gae,
            cfg.policy_gamma,
            cfg.policy_tau,
            use_proper_time_limits=cfg.use_proper_time_limits,
        )

        separate_gru = cfg.policy_separate_gru
        if separate_gru:
            pol_encoder = self.encoder_pol.encoder
        else:
            pol_encoder = None

        # update agent (this will also call the VAE update!)
        return self.policy.update(
            policy_storage=self.policy_storage,
            encoder=self.vae.encoder,
            encoder_pol=pol_encoder,
            rlloss_through_encoder=cfg.rlloss_through_encoder,
            policy_separate_gru=separate_gru,
            compute_vae_loss=self.vae.compute_vae_loss,
        )


    def log(self, run_stats, train_stats, start_time):
        # --- visualise behaviour of policy ---

        # --- evaluate policy ----
        if (self.iter_idx + 1) % self.args.eval_interval == 0:
            eval_start_time = time.time()
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Starting evaluation phase")
            
            os.makedirs('{}/{}'.format(self.logger.full_output_folder, self.iter_idx))
            ret_rms = None #we don't need normalised reward for eval
            total_parametric_num = self.args.parametric_num

            # Setup arrays
            setup_start_time = time.time()
            num_worker = 10
            returns_array = np.zeros((15, total_parametric_num, self.args.max_rollouts_per_task))
            latent_means_array = np.zeros((15, total_parametric_num, self.args.latent_dim))
            latent_logvars_array = np.zeros((15, total_parametric_num, self.args.latent_dim))
            w_array = np.zeros((15, total_parametric_num, self.args.latent_dim))  # Add w logging

            successes_array = np.zeros((15, total_parametric_num))
            save_episode_successes = True
            if save_episode_successes:
                episode_successes_array = np.zeros((15, total_parametric_num, self.args.max_rollouts_per_task))

            save_episode_probs = False
            #save_episode_probs = (self.iter_idx + 1) % (20 * self.args.eval_interval) == 0
            probs_array = np.zeros((15, total_parametric_num, self.args.vae_mixture_num))
            if save_episode_probs:
                episode_probs_array = np.zeros((15, total_parametric_num, self.args.max_rollouts_per_task,
                                                self.envs._max_episode_steps, self.args.vae_mixture_num))
            
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Array setup took: {time.time() - setup_start_time:.4f}s")

            # Main evaluation loop
            eval_loop_start_time = time.time()
            for task_class in range(15):
                task_class_start_time = time.time()
                for parametric_num in range(total_parametric_num // num_worker):
                    batch_start_time = time.time()
                    task_list = np.concatenate((np.expand_dims(np.repeat(task_class, num_worker), axis=1),
                                                np.expand_dims(np.arange(num_worker * parametric_num,
                                                                         num_worker * (parametric_num + 1)), axis=1)),axis=1)

                    eval_call_start_time = time.time()
                    returns_per_episode, latent_mean, latent_logvar, successes, prob, episode_probs, episode_successes, w = utl_eval.evaluate_metaworld_dme(
                        args=self.args,
                        policy=self.policy,
                        ret_rms=ret_rms,
                        encoder=self.vae.encoder,
                        encoder_pol=self.encoder_pol.encoder if self.encoder_pol is not None else None,
                        iter_idx=self.iter_idx,
                        tasks=None,
                        test=False,
                        task_list=task_list,
                        save_episode_probs=save_episode_probs,
                        save_episode_successes=save_episode_successes,
                        )
                    
                    if getattr(self.args, 'debug_time', False):
                        print(f"[DME DEBUG] evaluate_metaworld_DME call for task {task_class}, batch {parametric_num} took: {time.time() - eval_call_start_time:.4f}s")

                    # Store results
                    returns_array[task_class, parametric_num * num_worker:(parametric_num + 1) * num_worker, :] = returns_per_episode
                    latent_means_array[task_class, parametric_num * num_worker:(parametric_num + 1) * num_worker, :] = latent_mean
                    latent_logvars_array[task_class, parametric_num * num_worker:(parametric_num + 1) * num_worker, :] = latent_logvar
                    w_array[task_class, parametric_num * num_worker:(parametric_num + 1) * num_worker, :] = w  # Store w values
                    successes_array[task_class, parametric_num * num_worker:(parametric_num + 1) * num_worker] = successes
                    probs_array[task_class, parametric_num * num_worker:(parametric_num + 1) * num_worker, :] = prob
                    if save_episode_probs:
                        episode_probs_array[task_class, parametric_num * num_worker:(parametric_num + 1) * num_worker, :, :, :] = episode_probs
                    if save_episode_successes:
                        episode_successes_array[task_class, parametric_num * num_worker:(parametric_num + 1) * num_worker, :] = episode_successes
                    
                    if getattr(self.args, 'debug_time', False):
                        print(f"[DME DEBUG] Full batch took: {time.time() - batch_start_time:.4f}s")
                
                if getattr(self.args, 'debug_time', False):
                    print(f"[DME DEBUG] Task class {task_class} took: {time.time() - task_class_start_time:.4f}s")
            
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Full evaluation loop took: {time.time() - eval_loop_start_time:.4f}s")

            # Process results
            taskwise_mean_return = np.mean(np.mean(returns_array, axis=2), axis=1)
            taskwise_mean_final_return = np.mean(returns_array[:,:,-1], axis=1)
            taskwise_mean_success = np.mean(successes_array, axis=1)
            taskwise_mean_final_success = np.mean(episode_successes_array[:,:,-1], axis=1)

            # Print results
            print(f"Updates {self.iter_idx}, "
                  f"Frames {self.frames}, "
                  f"FPS {int(self.frames / (time.time() - start_time))}, \n"
                  f" Mean return per episode (train): {np.mean(taskwise_mean_return[:10])},"
                  f" Mean return per episode (test): {np.mean(taskwise_mean_return[10:])},\n"
                  f" Mean final return per episode (train): {np.mean(taskwise_mean_final_return[:10])},"
                  f" Mean final return per episode (test): {np.mean(taskwise_mean_final_return[10:])},\n"
                  f" Mean success rate (train): {np.mean(taskwise_mean_success[:10])},"
                  f" Mean final success rate (train): {np.mean(taskwise_mean_final_success[:10])},\n"
                  f" Mean success rate (test): {np.mean(taskwise_mean_success[10:])}"
                  f" Mean final success rate (test): {np.mean(taskwise_mean_final_success[10:])}"
                  )
            print("history: ", self.task_count)
            print("train taskwise success rates: ", taskwise_mean_success[:10])
            print("train taskwise final success rates: ", taskwise_mean_final_success[:10])
            print("test taskwise success rates: ", taskwise_mean_success[10:])
            print("test taskwise final success rates: ", taskwise_mean_final_success[10:])

            # Save CSV
            with open(self.logger.full_output_folder + '/log_eval.csv', 'a', encoding='UTF8') as f:
                writer = csv.writer(f)
                writer.writerow(np.concatenate(([self.iter_idx, int(self.frames)], taskwise_mean_return, taskwise_mean_success, taskwise_mean_final_success, taskwise_mean_final_return)))
            
            # Save numpy arrays
            np.save('{}/{}/returns.npy'.format(self.logger.full_output_folder, self.iter_idx), returns_array)
            np.save('{}/{}/latent_means.npy'.format(self.logger.full_output_folder, self.iter_idx), latent_means_array)
            np.save('{}/{}/latent_logvars.npy'.format(self.logger.full_output_folder, self.iter_idx),
                    latent_logvars_array)
            np.save('{}/{}/w_array.npy'.format(self.logger.full_output_folder, self.iter_idx), w_array)  # Save w values
            np.save('{}/{}/successes.npy'.format(self.logger.full_output_folder, self.iter_idx), successes_array)
            np.save('{}/{}/task_count.npy'.format(self.logger.full_output_folder, self.iter_idx), self.task_count)
            if save_episode_successes:
                np.save('{}/{}/episode_successes_array.npy'.format(self.logger.full_output_folder, self.iter_idx),
                        episode_successes_array)

            np.save('{}/{}/probs.npy'.format(self.logger.full_output_folder, self.iter_idx), probs_array)
            if save_episode_probs:
                np.save('{}/{}/episode_probs_array.npy'.format(self.logger.full_output_folder, self.iter_idx),
                        episode_probs_array)
            
            self.task_count = np.zeros((10))
            self.recent_train_success = taskwise_mean_success[:10]
            
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Full evaluation phase took: {time.time() - eval_start_time:.4f}s")
        
        # --- save models ---
        if (self.iter_idx + 1) % self.args.save_interval == 0:
            model_save_start_time = time.time()
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Starting model save phase")
            
            save_path = os.path.join(self.logger.full_output_folder, 'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)

            idx_labels = ['']
            if self.args.save_intermediate_models:
                idx_labels.append(int(self.iter_idx))

            for idx_label in idx_labels:

                torch.save(self.policy.actor_critic, os.path.join(save_path, f"policy{idx_label}.pt"))
                torch.save(self.vae.encoder, os.path.join(save_path, f"encoder{idx_label}.pt"))

                if self.vae.state_decoder is not None:
                    torch.save(self.vae.state_decoder, os.path.join(save_path, f"state_decoder{idx_label}.pt"))
                if self.vae.reward_decoder is not None:
                    torch.save(self.vae.reward_decoder, os.path.join(save_path, f"reward_decoder{idx_label}.pt"))
                if self.vae.task_decoder is not None:
                    torch.save(self.vae.task_decoder, os.path.join(save_path, f"task_decoder{idx_label}.pt"))
                if self.encoder_pol is not None:
                    torch.save(self.encoder_pol.encoder, os.path.join(save_path, f"encoder_pol{idx_label}.pt"))
                    torch.save(self.encoder_pol.optimiser_vae.state_dict(), os.path.join(save_path, f"encoder_pol_optimiser_pol{idx_label}.pt"))
                torch.save(self.vae.optimiser_vae.state_dict(), os.path.join(save_path, f"optimiser_vae{idx_label}.pt"))
                torch.save(self.policy.optimiser.state_dict(), os.path.join(save_path, f"optimiser_pol{idx_label}.pt"))
                # save normalisation params of envs
                if self.args.norm_rew_for_policy:
                    rew_rms = self.envs.venv.ret_rms
                    utl.save_obj(rew_rms, save_path, f"env_rew_rms{idx_label}")
                # TODO: grab from policy and save?
                if self.args.norm_state_for_policy:
                    obs_rms = self.policy.actor_critic.state_rms
                    utl.save_obj(obs_rms, save_path, f"pol_state_rms{idx_label}")
            
            if getattr(self.args, 'debug_time', False):
                print(f"[DME DEBUG] Model save phase took: {time.time() - model_save_start_time:.4f}s")


        # --- log some other things ---
        if train_stats is not None and self.iter_idx>0:
        #if ((self.iter_idx + 1) % self.args.log_interval == 0) and (train_stats is not None):

            self.logger.add('environment/state_max', self.policy_storage.prev_state.max(), self.iter_idx)
            self.logger.add('environment/state_min', self.policy_storage.prev_state.min(), self.iter_idx)

            self.logger.add('environment/rew_max', self.policy_storage.rewards_raw.max(), self.iter_idx)
            self.logger.add('environment/rew_min', self.policy_storage.rewards_raw.min(), self.iter_idx)

            self.logger.add('policy_losses/value_loss', train_stats[0], self.iter_idx)
            self.logger.add('policy_losses/action_loss', train_stats[1], self.iter_idx)
            self.logger.add('policy_losses/dist_entropy', train_stats[2], self.iter_idx)
            self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx)

            self.logger.add('policy/action', run_stats[0][0].float().mean(), self.iter_idx)
            if hasattr(self.policy.actor_critic, 'logstd'):
                self.logger.add('policy/action_logstd', self.policy.actor_critic.dist.logstd.mean(), self.iter_idx)
            self.logger.add('policy/action_logprob', run_stats[1].mean(), self.iter_idx)
            self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx)

            self.logger.add('encoder/latent_mean', torch.cat(self.policy_storage.latent_mean).mean(), self.iter_idx)
            self.logger.add('encoder/latent_logvar', torch.cat(self.policy_storage.latent_logvar).mean(), self.iter_idx)
            
            # Add w logging for DME
            if hasattr(self.policy_storage, 'w') and self.policy_storage.w:
                self.logger.add('encoder/w_mean', torch.cat(self.policy_storage.w).mean(), self.iter_idx)

            # log the average weights and gradients of all models (where applicable)
            for [model, name] in [
                [self.policy.actor_critic, 'policy'],
                [self.vae.encoder, 'encoder'],
                [self.vae.reward_decoder, 'reward_decoder'],
                [self.vae.state_decoder, 'state_transition_decoder'],
                [self.vae.task_decoder, 'task_decoder'],
                [self.encoder_pol.encoder, 'policy_encoder'] if self.encoder_pol is not None else [None, None]
            ]:
                if model is not None:
                    param_list = list(model.parameters())
                    param_mean = np.mean([param_list[i].data.cpu().numpy().mean() for i in range(len(param_list))])

                    self.logger.add('weights/{}'.format(name), param_mean, self.iter_idx)
                    if name == 'policy':
                        self.logger.add('weights/policy_std', param_list[0].data.mean(), self.iter_idx)
                    if param_list[0].grad is not None:
                        param_grad_mean = np.mean([param_list[i].grad.cpu().numpy().mean() for i in range(len(param_list))])
                        self.logger.add('gradients/{}'.format(name), param_grad_mean, self.iter_idx)

