import os
from pathlib import Path 
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.dataset.robomimic_replay_image_dataset import _convert_actions, undo_transform_action
from diffusion_policy.env_runner.robomimic_image_runner import create_env
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from dino_wm.datasets.pusht_dset_my import PushTImageDynamicsModelDataset
from dino_wm.datasets.robomimic_dset import RobomimicImageDynamicsModelDataset
import robomimic.utils.file_utils as FileUtils
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.model.common.normalizer import LinearNormalizer
from dino_wm.datasets.img_transforms import default_transform, get_eval_crop_transform, get_eval_crop_transform_resnet
from dino_wm.plan import load_model
import torch.optim as optim
import torch
from torch import nn
from torch.utils.data import DataLoader
from accelerate import Accelerator

import hydra
from omegaconf import OmegaConf, open_dict
from hydra.utils import instantiate
import copy
import numpy as np
import torch.linalg as linalg
import torch.nn.functional as F
from einops import rearrange
from torchvision import utils
from torchvision.utils import save_image
from torchvision.transforms.functional import to_pil_image, to_tensor
from PIL import ImageDraw, ImageFont

def add_text_to_image(img_tensor, text):
    """Return a copy of an image tensor with ``text`` drawn on it.

    The image is min-max scaled to [0, 1] for drawing with PIL and mapped back
    to its original value range afterwards; the result is placed on the same
    device as the input.

    Args:
        img_tensor: image tensor accepted by ``to_pil_image`` (presumably a
            ``(C, H, W)`` float tensor — confirm against callers).
        text: a number; it is rendered as ``f"{text:.2f}"``.

    Returns:
        The annotated image tensor, in the input's value range, with the text
        drawn in black near the top-right corner.
    """
    img_min, img_max = img_tensor.min(), img_tensor.max()
    value_range = img_max - img_min
    # A constant image has a zero range; dividing by it would yield NaNs.
    if value_range <= 0:
        value_range = 1.0
    img = to_pil_image((img_tensor - img_min) / value_range)
    draw = ImageDraw.Draw(img)
    width, _ = img.size
    text_pos = (width - 50, 50)
    try:
        font = ImageFont.truetype("arial.ttf", 40)
    except OSError:
        # Font not installed on this machine: use PIL's built-in bitmap font.
        font = ImageFont.load_default()
    draw.text(text_pos, f"{text:.2f}", fill="black", font=font)
    # to_tensor always returns a CPU tensor; move it back to the input's device
    # before mixing it with the (possibly GPU) min/range scalars.
    annotated = to_tensor(img).to(img_tensor.device)
    return annotated * value_range + img_min


class PushTPlanner:
    """Latent-space planner / guidance helper for the PushT task.

    Wraps a pretrained world model (``self.wm``) and a bank of encoded expert
    demonstration latents. Candidate actions are rolled out in the world model
    and scored by the L2 distance between the predicted visual latent and its
    nearest demonstration latent (``reward = -distance``, higher is better).

    Args:
        demo_dataset_config: Hydra/OmegaConf config used to instantiate the
            expert demo dataset. NOTE: :meth:`get_demo_latents` overwrites
            several of its fields (``zarr_path``, ``horizon``, padding).
        dynamics_model_ckpt: path to the world-model checkpoint; the directory
            two levels above it must contain ``hydra.yaml`` and
            ``normalizer.pth``.
        decoder_path: optional decoder checkpoint (``None`` to skip).
        value_func_path: optional value-network checkpoint (``None`` to skip).
        action_step: number of actions returned for execution by
            :meth:`plan_shooting`.
        output_dir: debug output directory; ``<output_dir>/images`` is created.
    """

    def __init__(
        self,
        demo_dataset_config,
        dynamics_model_ckpt,
        decoder_path,
        value_func_path=None,
        action_step=8,
        output_dir='debug/',
    ):
        self.accelerator = Accelerator()
        self.device = self.accelerator.device

        # Demo dataset config (mutated later by get_demo_latents()).
        self.demo_dataset_config = demo_dataset_config

        # World model: its hydra config sits two directories above the ckpt.
        dynamics_model_dir = os.path.dirname(os.path.dirname(dynamics_model_ckpt))
        with open(os.path.join(dynamics_model_dir, "hydra.yaml"), "r") as f:
            model_cfg = OmegaConf.load(f)
        self.wm = load_model(Path(dynamics_model_ckpt), model_cfg, 1, device=self.device)
        if not model_cfg.model.train_encoder and model_cfg.encoder_ckpt_path is not None:
            encoder_ckpt = torch.load(model_cfg.encoder_ckpt_path, map_location=self.device)
            self.wm.encoder.load_state_dict(encoder_ckpt['encoder'])
            print('loaded encoder from ', model_cfg.encoder_ckpt_path)

        # Optional decoder (debugging / visualisation only).
        if decoder_path is not None:
            decoder_ckpt = torch.load(decoder_path, map_location=self.device)
            self.wm.decoder = hydra.utils.instantiate(
                model_cfg.decoder,
                emb_dim=self.wm.encoder.emb_dim,
            )
            self.wm.decoder.load_state_dict(decoder_ckpt['decoder'])
            print('loaded decoder from ', decoder_path)

        self.wm = self.accelerator.prepare(self.wm)
        self.wm.disable_reconstruction = False
        self.wm.eval()

        if value_func_path is not None:
            # NOTE(review): ConvValueNetwork is not imported in this file; this
            # branch raises NameError unless the name is provided elsewhere.
            value_net = ConvValueNetwork(emb_dim_in=768).to(self.device)
            value_net_checkpoint = torch.load(value_func_path, map_location=self.device)
            value_net.load_state_dict(value_net_checkpoint['model_state_dict'])
            self.value_net = self.accelerator.prepare(value_net)
            self.value_net.eval()
        else:
            self.value_net = None

        # Normalizers: world-model normalizer from disk; the policy's action
        # normalizer is injected later via set_policy_action_normalizer().
        wm_normalizer = LinearNormalizer()
        wm_normalizer.load_state_dict(
            torch.load(os.path.join(dynamics_model_dir, "normalizer.pth"), map_location=self.device)
        )
        self.wm_normalizer = wm_normalizer.to(self.device)
        self.policy_action_normalizer = LinearNormalizer()

        self.original_action_dim = 2
        self.abs_action = True

        # Image preprocessing: raw views are 140x140; the ResNet path yields 128x128.
        self.use_crop = model_cfg.use_crop
        self.use_resnet_encoder = model_cfg.use_resnet_encoder
        self.original_img_size = 140
        self.transformed_img_size = 128
        if not self.use_crop:
            # No transform is defined without cropping; _preprocess_views()
            # raises a clear error if one is needed.
            self.img_transform = None
        elif self.use_resnet_encoder:
            self.img_transform = get_eval_crop_transform_resnet(self.original_img_size)
        else:
            self.img_transform = get_eval_crop_transform(self.original_img_size)

        self.view_names = model_cfg.view_names

        self.frameskip = model_cfg.frameskip
        self.action_dim = 2
        self.exec_step = action_step
        print('self.demo_dataset_config.horizon ', self.demo_dataset_config.horizon)
        # Number of world-model steps covered by the policy's action horizon.
        wm_steps_for_horizon = {32: 2, 16: 1}
        self.horizon = wm_steps_for_horizon.get(self.demo_dataset_config.horizon)
        if self.horizon is None:
            print(
                'WARNING: unsupported demo_dataset_config.horizon ',
                self.demo_dataset_config.horizon,
                '; plan_shooting/cond_fn/compute_loss will not work',
            )

        self.get_demo_latents()

        self.timestep = 0
        self.output_dir = output_dir
        self.image_output_dir = os.path.join(output_dir, 'images')
        os.makedirs(self.image_output_dir, exist_ok=True)
        self.idx = 0

    def set_policy_action_normalizer(self, policy_action_normalizer):
        """Set the normalizer used to unnormalize policy samples in cond_fn/compute_loss."""
        self.policy_action_normalizer = policy_action_normalizer

    def _preprocess_views(self, visual):
        """Prepare every view in ``self.view_names`` for the world-model encoder.

        Each view is moved to ``self.device``, normalized with the world-model
        normalizer when the ResNet encoder is used, passed through
        ``self.img_transform`` and reshaped to ``(-1, 1, 3, S, S)``, where ``S``
        is 128 on the ResNet path and 140 otherwise. Mutates and returns
        ``visual``.
        """
        if self.use_resnet_encoder:
            out_size = self.transformed_img_size
        else:
            out_size = self.original_img_size
        for view_name in self.view_names:
            if self.img_transform is None:
                raise RuntimeError(
                    'PushTPlanner has no image transform; set use_crop=True in the world-model config'
                )
            img = visual[view_name].to(self.device)
            if self.use_resnet_encoder:
                img = self.wm_normalizer[view_name].normalize(img)
            img = self.img_transform(
                img.view(-1, 3, self.original_img_size, self.original_img_size)
            )
            visual[view_name] = img.view(-1, 1, 3, out_size, out_size)
        return visual

    def get_demo_latents(self):
        """Encode the expert demo dataset into a bank of world-model latents.

        Populates ``self.demo_visual_latents`` (flattened to ``(N, D)``, on CPU),
        ``self.demo_proprio_latents`` (on CPU) and ``self.demo_images``
        (preprocessed images per view, on CPU).

        Side effect: overwrites ``zarr_path``, ``horizon``, ``pad_before`` and
        ``pad_after`` on ``self.demo_dataset_config``.

        Raises:
            ValueError: if the demo dataset yields no samples.
        """
        # NOTE(review): placeholder path; point this at the real expert dataset.
        self.demo_dataset_config.zarr_path = 'path_to_expert_dataset'
        # One frame per sample and no padding.
        self.demo_dataset_config.horizon = 1
        self.demo_dataset_config.pad_before = 0
        self.demo_dataset_config.pad_after = 0

        demo_dataset: BaseImageDataset = hydra.utils.instantiate(self.demo_dataset_config)
        print('len(demo_dataset) ', len(demo_dataset))
        demo_loader = DataLoader(demo_dataset, batch_size=64, shuffle=False, num_workers=4)

        visual_latents = []
        proprio_latents = []
        images = []
        with torch.no_grad():
            for batch in demo_loader:
                # Keep only the most recent observation step.
                obs = dict_apply(batch['obs'], lambda x: x[:, -1:, ...])
                visual = self._preprocess_views({'image': obs['image']})
                images.append({k: v.cpu() for k, v in visual.items()})

                obs_wm = {
                    'visual': visual,
                    'proprio': self.wm_normalizer['state'].normalize(obs['agent_pos']),
                }
                encoded = self.wm.encode_obs(obs_wm)
                visual_latents.append(encoded['visual'].cpu())
                proprio_latents.append(encoded['proprio'].cpu())
                torch.cuda.empty_cache()

        if not visual_latents:
            raise ValueError('expert demo dataset is empty; cannot build demo latents')

        self.demo_visual_latents = torch.cat(visual_latents, dim=0)
        self.demo_proprio_latents = torch.cat(proprio_latents, dim=0)
        # Nearest-neighbour search compares flat vectors.
        if self.demo_visual_latents.dim() > 2:
            self.demo_visual_latents = self.demo_visual_latents.reshape(
                self.demo_visual_latents.size(0), -1
            )

        print('demo_visual_latents shape ', self.demo_visual_latents.shape)
        print('demo_proprio_latents shape ', self.demo_proprio_latents.shape)
        self.demo_images = {
            key: torch.cat([d[key] for d in images], dim=0)
            for key in images[0]
        }
        for key, value in self.demo_images.items():
            print(key, value.shape)

        del demo_dataset

    # deprecated
    def compute_current_reward(self, current_obs):
        """Return the nearest-demo reward of the current observation (deprecated).

        Encodes ``current_obs`` (unnormalized) without gradients and returns a
        1-element reward tensor from :meth:`compute_nn_reward`.
        """
        with torch.no_grad():
            current_obs_wm = self.prepare_obs(current_obs, 1)
            encode_obs = self.wm.encode_obs(current_obs_wm)
            current_visual_latent = encode_obs['visual']
            reward, _ = self.compute_nn_reward(current_visual_latent.squeeze(1))
        return reward

    def compute_nn_reward(self, current_visual_latent, chunk_size=2048):
        """Negative L2 distance from each latent to its nearest demo latent.

        Args:
            current_visual_latent: ``(B, ...)`` tensor; dims after the first are
                flattened before comparison with ``self.demo_visual_latents``.
                Gradients flow back through it.
            chunk_size: number of demo latents moved to the device and compared
                at once; bounds peak memory.

        Returns:
            ``(reward, nn_idx)``: ``reward`` is ``(B,)`` with
            ``-min_j ||x - demo_j||``; ``nn_idx`` is ``(B,)`` with the index of
            the nearest demo latent.

        Raises:
            ValueError: if there are no demo latents or ``chunk_size < 1``.
        """
        if current_visual_latent.dim() > 2:
            current_visual_latent = current_visual_latent.reshape(current_visual_latent.size(0), -1)
        num_demos = self.demo_visual_latents.shape[0]
        if num_demos == 0:
            raise ValueError('no demo latents to compare against')
        if chunk_size < 1:
            raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')

        device = current_visual_latent.device
        best_cost = None
        best_idx = None
        for start in range(0, num_demos, chunk_size):
            demo_chunk = self.demo_visual_latents[start:start + chunk_size].to(device, non_blocking=True)
            dist = torch.cdist(current_visual_latent, demo_chunk, p=2)  # (B, chunk)
            cost, idx = dist.min(dim=-1)  # (B,), (B,)
            idx = idx + start  # chunk-local -> global index
            if best_cost is None:
                best_cost, best_idx = cost, idx
            else:
                # Out-of-place update: keeps autograd simple and does not mutate
                # a previous chunk's min() output.
                better = cost < best_cost
                best_cost = torch.where(better, cost, best_cost)
                best_idx = torch.where(better, idx, best_idx)
        return -best_cost, best_idx

    def plan_shooting(self, current_obs, current_state, init_actions, eval_id=None, timestep=None):
        """Pick the best of a batch of candidate action sequences.

        Args:
            current_obs: unnormalized observation.
            current_state: unused.
            init_actions: unnormalized candidate actions ``(N, T, action_dim)``;
                only the first ``horizon * frameskip`` steps are scored.
                ``N`` must be 8.
            eval_id, timestep: unused.

        Returns:
            ``(actions, best_index)``: the first ``exec_step`` unnormalized
            actions of the best candidate with a leading batch dim of 1, and
            that candidate's index.
        """
        self.timestep += 1
        init_actions = self.wm_normalizer['act'].normalize(init_actions[:, :self.horizon * self.frameskip])
        print('1 init_actions ', init_actions.shape)
        num_samples = init_actions.shape[0]
        bs = 8

        assert num_samples == bs
        value = torch.zeros((num_samples), device=self.device)
        # Score every candidate (including any partial last batch) so none is
        # left with the placeholder value 0, which would beat negative rewards.
        for start in range(0, num_samples, bs):
            action_batch = init_actions[start:start + bs, :, :]
            # NOTE(review): compute_rollout_reward is not defined on this class;
            # it must be supplied by a subclass or elsewhere.
            value[start:start + bs] = self.compute_rollout_reward(current_obs, action_batch)
            torch.cuda.empty_cache()

        best_value, best_index = value.max(0)
        print('best_value ', best_value.item())
        best_action = init_actions[best_index]
        print('best_action ', best_action.shape)

        best_action_to_execute = best_action[:self.exec_step]
        best_action_to_execute_unnormalized = self.wm_normalizer['act'].unnormalize(best_action_to_execute)
        return best_action_to_execute_unnormalized.unsqueeze(0), best_index

    def clean_grad(self):
        """Zero the existing ``.grad`` buffers of all world-model parameters."""
        for param in self.wm.parameters():
            if param.grad is not None:
                param.grad.zero_()

    def _sample_to_wm_actions(self, sample):
        """Convert a policy-normalized action sample into world-model action chunks.

        Skips index 0 along the time axis, keeps the next
        ``horizon * frameskip`` actions, converts them from the policy's
        normalization to the world model's, and groups them into
        ``(B, horizon, frameskip * action_dim)``.
        """
        actions = sample[:, 1:1 + self.horizon * self.frameskip]
        actions = self.policy_action_normalizer.unnormalize(actions)
        actions = self.wm_normalizer['act'].normalize(actions)
        return rearrange(actions, 'b (h f) a -> b h (f a)', f=self.frameskip, h=self.horizon)

    def cond_fn(self, sample, current_obs):
        """
        Computes the classifier guidance gradient from a world-model rollout.

        Args:
            sample (torch.Tensor): normalized (policy normalizer) noisy action,
                ``(B, T, action_dim)``.
            current_obs (torch.Tensor): unnormalized obs

        Returns:
            torch.Tensor: The gradient of the mean nearest-demo cost with respect
            to the sample (same shape as ``sample``).
        """
        self.clean_grad()
        sample = sample.detach().requires_grad_(True)
        action_batch = self._sample_to_wm_actions(sample)
        batch_size = action_batch.shape[0]
        current_obs_wm = self.prepare_obs(current_obs, batch_size)

        act_0 = action_batch[:, :1, :]  # only the first action chunk is encoded
        z = self.wm.encode(current_obs_wm, act_0)

        total_rew = torch.zeros(batch_size, device=sample.device, dtype=torch.float)
        # Always take at least one step: with horizon == 1 there are no further
        # action chunks, and an empty rollout would leave the loss without a
        # graph so autograd.grad would raise.
        n_steps = max(1, action_batch.shape[1] - 1)
        # NOTE(review): the predicted latent is not fed back into the model, so
        # every step repeats the same one-step prediction from z.
        for _ in range(n_steps):
            z_pred = self.wm.predict(z)
            z_obs, _ = self.wm.separate_emb(z_pred[:, -1:, ...])
            rew, _ = self.compute_nn_reward(z_obs['visual'])
            # mean() so the debug print also works for batch sizes > 1.
            print('future cost 1', -rew.detach().mean().item())
            total_rew += rew

        loss = (-1 * total_rew).mean()

        grad = torch.autograd.grad(loss, sample, retain_graph=True)[0]
        print("Gradient norm:", grad.norm().item())

        self.idx += 1
        return grad

    def compute_loss(self, sample, current_obs):
        """Differentiable mean one-step nearest-demo cost of ``sample``.

        Args:
            sample: policy-normalized action sample ``(B, T, action_dim)``.
            current_obs: unnormalized observation.

        Returns:
            Scalar loss tensor (mean of ``-reward`` over the batch).
        """
        print('init_actions_normalized ', sample[:, 1:1 + self.horizon * self.frameskip].shape)
        action_batch = self._sample_to_wm_actions(sample)
        batch_size = action_batch.shape[0]
        current_obs_wm = self.prepare_obs(current_obs, batch_size)

        act_0 = action_batch[:, :1, :]  # only the first action chunk is encoded
        z = self.wm.encode(current_obs_wm, act_0)

        # Single one-step prediction.
        z_pred = self.wm.predict(z)
        z_obs, _ = self.wm.separate_emb(z_pred[:, -1:, ...])
        rew, _ = self.compute_nn_reward(z_obs['visual'])
        print('future cost 1', -rew.detach().mean().item())

        total_rew = torch.zeros(batch_size, device=sample.device, dtype=torch.float)
        total_rew += rew
        loss = (-1 * total_rew).mean()
        self.idx += 1
        return loss

    def optimize_gd(self, current_action, current_obs, frame_assembled_success=False):
        """Refine one action chunk by gradient descent on the one-step NN cost.

        Args:
            current_action: unnormalized actions ``(B, frameskip, action_dim)``.
            current_obs: unnormalized observation.
            frame_assembled_success: unused; kept for interface compatibility.

        Returns:
            ``current_action`` unchanged if the current observation's cost is
            already below 3.2; otherwise the optimized, unnormalized actions
            (detached from the autograd graph).
        """
        current_cost = -1 * self.compute_current_reward(current_obs)[0]
        print('current_cost ', current_cost.item())
        # Already close enough to the demonstrations: keep the given action.
        if current_cost.item() < 3.2:
            return current_action

        init_actions = self.wm_normalizer['act'].normalize(current_action)
        init_actions = rearrange(init_actions, 'b (h f) a -> b h (f a)', f=self.frameskip, h=1)
        init_actions = init_actions.clone().detach().requires_grad_(True)
        batch_size = init_actions.shape[0]

        optimizer = optim.Adam([init_actions], lr=3e-1)
        current_obs_wm = self.prepare_obs(current_obs, batch_size)

        for optim_iter in range(20):
            optimizer.zero_grad()
            z = self.wm.encode(current_obs_wm, init_actions)
            z_pred = self.wm.predict(z[:, -1:])
            z_obs, _ = self.wm.separate_emb(z_pred[:, -1:, ...])
            rew, _ = self.compute_nn_reward(z_obs['visual'])
            loss = (-1 * rew).mean()
            # Differentiate w.r.t. the actions only, so world-model parameter
            # .grad buffers are not filled/accumulated across iterations.
            (init_actions.grad,) = torch.autograd.grad(loss, [init_actions])
            optimizer.step()
            print('optim_iter ', optim_iter, 'loss ', loss.item())

        with torch.no_grad():
            init_actions = rearrange(init_actions, 'b h (f a) -> b (h f) a', f=self.frameskip, h=1)
            return self.wm_normalizer['act'].unnormalize(init_actions)

    def prepare_obs(self, current_obs, action_shape):
        """Convert one raw observation into world-model input, tiled over a batch.

        Args:
            current_obs: dict with unnormalized ``'agent_pos'`` and ``'image'``
                tensors.
            action_shape: batch size (an int) to expand the observation to.

        Returns:
            ``{'visual': {view: tensor expanded to action_shape}, 'proprio':
            normalized proprio expanded to action_shape}``.
        """
        proprio = current_obs['agent_pos'].to(self.device)
        visual = self._preprocess_views({'image': current_obs['image'].to(self.device)})
        proprio = self.wm_normalizer['state'].normalize(proprio)
        return {
            'visual': {
                key: value.expand(action_shape, -1, -1, -1, -1)
                for key, value in visual.items()
            },
            'proprio': proprio.expand(action_shape, -1, -1),
        }
    
  