import os
from pathlib import Path 
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.dataset.robomimic_replay_image_dataset import _convert_actions, undo_transform_action
from diffusion_policy.env_runner.robomimic_image_runner import create_env
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from dino_wm.datasets.robomimic_dset import RobomimicImageDynamicsModelDataset
import robomimic.utils.file_utils as FileUtils
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.model.common.normalizer import LinearNormalizer
from dino_wm.datasets.img_transforms import default_transform, get_eval_crop_transform, get_eval_crop_transform_resnet
from dino_wm.plan import load_model
import torch.optim as optim
import torch
from torch import nn
from torch.utils.data import DataLoader
from accelerate import Accelerator

import hydra
from omegaconf import OmegaConf, open_dict
from hydra.utils import instantiate
import copy
import numpy as np
import torch.linalg as linalg
import torch.nn.functional as F
from einops import rearrange
from torchvision import utils
from torchvision.utils import save_image
from torchvision.transforms.functional import to_pil_image, to_tensor
from PIL import ImageDraw, ImageFont

def add_text_to_image(img_tensor, text):
    """Draw a number (formatted with two decimals) onto an image tensor.

    The tensor is min-max scaled to ``[0, 1]`` so PIL can render it, the text is
    drawn in black at 50 px from the right edge and 50 px from the top, and the
    result is mapped back to the input's original ``[min, max]`` value range.

    Args:
        img_tensor: image tensor accepted by torchvision's ``to_pil_image``.
        text: numeric value to render; formatted with ``f"{text:.2f}"``.

    Returns:
        Float image tensor (``to_tensor`` output) rescaled to the input's range.
    """
    img_min, img_max = img_tensor.min(), img_tensor.max()
    value_range = img_max - img_min
    if value_range == 0:
        # Constant image: dividing by a zero range would make every pixel NaN.
        value_range = torch.ones_like(value_range)
    img_tensor = (img_tensor - img_min) / value_range
    img = to_pil_image(img_tensor)
    draw = ImageDraw.Draw(img)
    width, _ = img.size
    text_pos = (width - 50, 50)
    try:
        font = ImageFont.truetype("arial.ttf", 40)
    except IOError:
        # Arial is not available on every machine; fall back to PIL's default font.
        font = ImageFont.load_default()
    draw.text(text_pos, f"{text:.2f}", fill="black", font=font)
    img_tensor = to_tensor(img)
    return img_tensor * value_range + img_min


class Planner:
    def __init__(
        self,
        demo_dataset_config,
        dynamics_model_ckpt,
        decoder_path,
        value_func_path=None,
        action_step=8,
        output_dir='debug/'

    ):
        """Load the world model, normalizers, demo latents and eval env used for planning.

        Args:
            demo_dataset_config: dataset config (Hydra/OmegaConf); must expose
                ``abs_action``, ``shape_meta``, ``dataset_path`` and ``horizon``.
            dynamics_model_ckpt: path to the world-model checkpoint. The directory
                two levels above it must contain ``hydra.yaml``,
                ``normalizer.pth`` and ``train_action_stats.pth``.
            decoder_path: decoder checkpoint path, or ``None`` to skip loading
                a decoder (debugging aid).
            value_func_path: optional value-network checkpoint path.
            action_step: number of planned actions returned for execution.
            output_dir: directory for debug output; an ``images`` subdirectory
                is created inside it.

        Raises:
            ValueError: if the world model and demo dataset disagree on
                ``abs_action``, or the demo dataset horizon is not 16 or 32.
        """
        self.accelerator = Accelerator()
        self.device = self.accelerator.device

        # demo dataset
        self.demo_dataset_config = demo_dataset_config
        # wm: the run directory (two levels above the checkpoint) holds the
        # training config and normalizer/action statistics.
        dynamics_model_dir = os.path.dirname(os.path.dirname(dynamics_model_ckpt))
        with open(os.path.join(dynamics_model_dir, "hydra.yaml"), "r") as f:
            model_cfg = OmegaConf.load(f)
        # Raise (not assert) so the check still runs under `python -O`.
        if model_cfg.abs_action != self.demo_dataset_config.abs_action:
            raise ValueError(
                f"abs_action mismatch: world model uses {model_cfg.abs_action}, "
                f"demo dataset uses {self.demo_dataset_config.abs_action}"
            )
        self.wm = load_model(Path(dynamics_model_ckpt), model_cfg, 1, device=self.device)
        if not model_cfg.model.train_encoder and model_cfg.encoder_ckpt_path is not None:
            # Frozen encoder trained separately: load its weights explicitly.
            encoder_ckpt = torch.load(model_cfg.encoder_ckpt_path, map_location=self.device)
            self.wm.encoder.load_state_dict(encoder_ckpt['encoder'])
            print('loaded encoder from ', model_cfg.encoder_ckpt_path)

        ###### temp, for debugging #####
        if decoder_path is not None:
            decoder_ckpt = torch.load(decoder_path, map_location=self.device)
            self.wm.decoder = hydra.utils.instantiate(
                model_cfg.decoder,
                emb_dim=self.wm.encoder.emb_dim
            )
            self.wm.decoder.load_state_dict(decoder_ckpt['decoder'])
            print('loaded decoder from ', decoder_path)

        self.wm = self.accelerator.prepare(self.wm)
        self.wm.disable_reconstruction = False
        self.wm.eval()

        if value_func_path is not None:
            # NOTE(review): ConvValueNetwork is not imported in this module's
            # visible imports; confirm it is in scope before using value_func_path.
            value_net = ConvValueNetwork(emb_dim_in=768).to(self.device)
            value_net_checkpoint = torch.load(value_func_path, map_location=self.device)
            value_net.load_state_dict(value_net_checkpoint['model_state_dict'])
            self.value_net = self.accelerator.prepare(value_net)
            self.value_net.eval()
        else:
            self.value_net = None
        # normalizer
        wm_normalizer = LinearNormalizer()
        wm_normalizer.load_state_dict(torch.load(os.path.join(dynamics_model_dir, "normalizer.pth")))
        self.wm_normalizer = wm_normalizer.to(self.device)
        # Placeholder until set_policy_action_normalizer() installs the real one.
        self.policy_action_normalizer = LinearNormalizer()

        # eval_env
        self.abs_action = self.demo_dataset_config.abs_action
        shape_meta = demo_dataset_config.shape_meta
        env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=demo_dataset_config.dataset_path)
        env_meta['env_kwargs']['controller_configs']['control_delta'] = not self.abs_action
        self.eval_env = create_env(
            env_meta=env_meta,
            shape_meta=shape_meta
        )
        self.original_action_dim = 10
        if "Transport" in env_meta["env_name"]:
            # Two robots -> twice the action dimension.
            self.original_action_dim = 20

        self.use_crop = model_cfg.use_crop
        self.use_resnet_encoder = model_cfg.use_resnet_encoder
        self.original_img_size = 140
        self.transformed_img_size = 128
        # NOTE(review): img_transform is only assigned when use_crop is True, but
        # get_demo_latents/prepare_obs always use it.
        if self.use_crop:
            if self.use_resnet_encoder:
                self.img_transform = get_eval_crop_transform_resnet(self.original_img_size)
            else:
                self.img_transform = get_eval_crop_transform(self.original_img_size)

        self.view_names = model_cfg.view_names

        self.env_name = env_meta["env_name"]
        print('env_name ', self.env_name)

        self.frameskip = model_cfg.frameskip
        self.action_dim = self.demo_dataset_config.shape_meta.action.shape[0]
        self.exec_step = action_step
        print('self.demo_dataset_config.horizon ', self.demo_dataset_config.horizon)
        # Number of world-model steps (each consuming `frameskip` actions) per plan.
        if self.demo_dataset_config.horizon == 32:
            self.horizon = 2
        elif self.demo_dataset_config.horizon == 16:
            self.horizon = 1
        else:
            # Fail here instead of with an AttributeError deep inside planning.
            raise ValueError(
                f"Unsupported demo dataset horizon {self.demo_dataset_config.horizon}; "
                "expected 16 or 32"
            )

        self.get_demo_latents()

        train_action_stats = torch.load(os.path.join(dynamics_model_dir, "train_action_stats.pth"))
        self.training_action_mean = train_action_stats['action_mean']
        self.training_action_std = train_action_stats['action_std']
        self.training_action_max = train_action_stats['action_max']
        self.training_action_min = train_action_stats['action_min']

        if "Nut" in env_meta["env_name"]:
            with open("square_model_file_1.4.xml", 'r') as f:
                print('reading square_model_file_1.4 in planner')
                self.model_file = f.read()
        elif "Transport" in env_meta["env_name"]:
            with open("transport_model_file_1.4.xml", 'r') as f:
                print('reading transport_model_file_1.4 in planner')
                self.model_file = f.read()
        elif 'ToolHang' in env_meta["env_name"]:
            with open("tool_hang_model_file_1.4.xml", 'r') as f:
                print('reading tool_hang_model_file_1.4 in planner')
                self.model_file = f.read()

        self.timestep = 0
        self.rotation_transformer = RotationTransformer(
                from_rep='axis_angle', to_rep='rotation_6d')
        self.output_dir = output_dir
        self.image_output_dir = os.path.join(output_dir, 'images')
        os.makedirs(self.image_output_dir, exist_ok=True)
        self.idx = 0

    def set_policy_action_normalizer(self, policy_action_normalizer):
        """Install the normalizer for the policy's action space.

        ``cond_fn`` and ``compute_loss`` call its ``unnormalize`` on sampled
        actions before re-normalizing them with the world-model normalizer.
        """
        self.policy_action_normalizer = policy_action_normalizer

    def get_demo_latents(self,):
        """Encode an expert demo dataset into the reference latents for the NN reward.

        Overrides fields of ``self.demo_dataset_config`` in place (dataset path
        per env, one observation step per sample, no padding), instantiates the
        dataset, encodes each sample's last observation step with
        ``self.wm.encode_obs`` and stores the results on CPU:

        - ``self.demo_visual_latents``: visual latents concatenated on dim 0 and,
          if they have more than two dims, flattened to ``(N, -1)``; ToolHang
          then keeps ``[..., 512:]`` and Transport keeps ``[..., :1024]`` (the
          same slices ``compute_nn_reward`` applies to query latents).
        - ``self.demo_proprio_latents``: proprio latents concatenated on dim 0.
        - ``self.demo_images``: per-view transformed images concatenated on dim 0.

        NOTE(review): the dataset paths below are the placeholder string
        'path_to_expert_dataset' and must be replaced before use.
        """
        demo_dataset: BaseImageDataset
        # Point the config at the expert data for this env. ToolHang keeps the
        # config's val_ratio; Nut and Transport force it to 0.0.
        if 'Nut' in self.env_name:
            self.demo_dataset_config.dataset_path = 'path_to_expert_dataset'
            self.demo_dataset_config.val_ratio = 0.0
        
        elif 'ToolHang' in self.env_name:
            self.demo_dataset_config.dataset_path = "path_to_expert_dataset"

        elif 'Transport' in self.env_name:
            self.demo_dataset_config.dataset_path = "path_to_expert_dataset"
            self.demo_dataset_config.val_ratio = 0.0

        # One observation step per sample with no padding. Note these writes
        # mutate the config object stored by __init__ (no copy is made).
        self.demo_dataset_config.horizon = 1
        self.demo_dataset_config.n_obs_steps = 1
        self.demo_dataset_config.pad_before = 0
        self.demo_dataset_config.pad_after = 0

        demo_dataset = hydra.utils.instantiate(self.demo_dataset_config)
        print('len(demo_dataset) ', len(demo_dataset))
        # shuffle=False keeps latent row i aligned with dataset order across all lists.
        demo_loader = DataLoader(demo_dataset, batch_size=64, shuffle=False, num_workers=4)

        demo_visual_latents = []
        demo_proprio_latents = []
        demo_images = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(demo_loader):
                obs = batch['obs']
                # Keep only the last observation step (time dim kept with length 1).
                obs = dict_apply(obs, lambda x: x[:, -1:, ...])
                robot0_eef_pos = obs['robot0_eef_pos']
                robot0_eef_quat = obs['robot0_eef_quat']
                robot0_gripper_qpos = obs['robot0_gripper_qpos']

                # Proprio layout [eef_pos, eef_quat, gripper_qpos] (robot1's appended
                # for Transport) must match the layout prepare_obs builds online.
                proprio = torch.cat([robot0_eef_pos, robot0_eef_quat, robot0_gripper_qpos], dim=-1).to(self.device)
                
                if 'Transport' in self.env_name:
                    robot1_eef_pos = obs['robot1_eef_pos'].to(self.device)
                    robot1_eef_quat = obs['robot1_eef_quat'].to(self.device)
                    robot1_gripper_qpos = obs['robot1_gripper_qpos'].to(self.device)
                    proprio = torch.cat([proprio, robot1_eef_pos, robot1_eef_quat, robot1_gripper_qpos], dim=-1)

                visual = {}

                for view_name in self.view_names:
                    # print('visual_name ', view_name)
                    bs, h = obs[view_name].shape[:2]
                    visual[view_name] = obs[view_name].to(self.device)
                    # Only the ResNet-encoder path normalizes images before the
                    # transform (same as prepare_obs).
                    if self.use_resnet_encoder:
                        visual[view_name] = self.wm_normalizer[view_name].normalize(visual[view_name])
                    # NOTE(review): self.img_transform is only set in __init__ when
                    # use_crop is True; otherwise this raises AttributeError.
                    visual[view_name] = self.img_transform(visual[view_name].view(-1, 3, 140, 140))
                    # Hard-coded sizes: 140x140 input; ResNet transform output is
                    # viewed as 128x128, the other path is viewed back as 140x140.
                    if self.use_resnet_encoder:
                        visual[view_name] = visual[view_name].view(-1, 1, 3, 128, 128)
                    else:
                        visual[view_name] = visual[view_name].view(-1, 1, 3, 140, 140)
                # Images are kept on CPU to limit GPU memory.
                visual_cpu = {k: v.cpu() for k, v in visual.items()}
                demo_images.append(visual_cpu)

                obs_wm = {'visual': visual, 'proprio': proprio}
                obs_wm['proprio'] = self.wm_normalizer['state'].normalize(obs_wm['proprio'])

                encode_obs = self.wm.encode_obs(obs_wm)
                demo_visual_latents.append(encode_obs['visual'].cpu())
                demo_proprio_latents.append(encode_obs['proprio'].cpu())
                torch.cuda.empty_cache()

        self.demo_visual_latents = torch.cat(demo_visual_latents, dim=0)
        self.demo_proprio_latents = torch.cat(demo_proprio_latents, dim=0)

        # Flatten to (N, features) for cdist; the env-specific slices must stay
        # in sync with compute_nn_reward.
        if len(self.demo_visual_latents.shape) > 2:
            self.demo_visual_latents = self.demo_visual_latents.reshape(self.demo_visual_latents.size(0), -1)
            if 'ToolHang' in self.env_name:
                self.demo_visual_latents = self.demo_visual_latents[...,512:]
            elif 'Transport' in self.env_name:
                self.demo_visual_latents = self.demo_visual_latents[...,:1024]

        print('demo_visual_latents shape ', self.demo_visual_latents.shape)
        print('demo_proprio_latents shape ', self.demo_proprio_latents.shape)
        # Merge per-batch image dicts into one tensor per view.
        self.demo_images = {
            key: torch.cat([d[key] for d in demo_images], dim=0)
            for key in demo_images[0].keys()
        }
        for key in self.demo_images.keys():
            print(key, self.demo_images[key].shape)

        del demo_dataset

    # deprecated
    def init_mu_sigma(self, actions):
        """Build the initial sampling mean and std for the (deprecated) MPPI planner.

        Args:
            actions: optional warm-start actions ``(1 or num_samples, T, action_dim)``,
                unnormalized; the first ``horizon * frameskip`` steps are
                normalized with ``wm_normalizer['act']``. ``None`` starts from
                the training action mean.

        Returns:
            ``(mu, sigma)``, both of shape
            ``(num_samples, horizon * frameskip, action_dim)``. ``sigma`` is
            1.5x the training action std. Steps not covered by ``actions`` are
            filled with the training action mean.

        NOTE(review): the normalized warm start is mixed with
        training_action_mean/std; confirm those stats are in the same space.
        """
        plan_len = self.horizon * self.frameskip
        if actions is not None:
            actions = self.wm_normalizer['act'](actions[:, :plan_len])

        sigma = 1.5 * self.training_action_std.unsqueeze(0).unsqueeze(0)
        sigma = sigma.expand(self.num_samples, plan_len, self.action_dim)

        mean = self.training_action_mean.unsqueeze(0).unsqueeze(0)  # (1, 1, action_dim)
        if actions is None:
            mu = mean.expand(self.num_samples, 1, self.action_dim)
        else:
            # Broadcast a single warm-start sequence to every sample so it can be
            # concatenated with the (num_samples, ...) mean padding below.
            mu = actions.expand(self.num_samples, -1, -1)

        remaining_t = plan_len - mu.shape[1]
        if remaining_t > 0:
            pad = mean.to(mu.device).expand(self.num_samples, remaining_t, self.action_dim)
            mu = torch.cat([mu, pad], dim=1)
        return mu, sigma

    def compute_rollout_reward(self, current_obs, action_batch):
        """Score candidate action chunks by the world model's one-step prediction.

        The current observation is encoded with the first ``frameskip`` actions
        of each candidate, the world model predicts the next latent, and the
        reward is ``compute_nn_reward`` of that predicted visual latent (minus
        the distance to the nearest demo latent). Only this single predicted
        step is scored; actions after the first chunk do not affect the result.

        Args:
            current_obs: unnormalized observation dict accepted by ``prepare_obs``.
            action_batch: normalized actions of shape
                ``(bs, horizon * frameskip, action_dim)``.

        Returns:
            Reward tensor from ``compute_nn_reward``, one entry per candidate.
        """
        with torch.no_grad():
            current_obs_wm = self.prepare_obs(current_obs, action_batch.shape[0])
            action_batch = rearrange(action_batch, 'b (h f) a -> b h (f a)', f=self.frameskip, h=self.horizon)
            act_0 = action_batch[:, :1]
            z = self.wm.encode(current_obs_wm, act_0)

            # Always exactly one prediction step, so horizon == 1 (no trailing
            # chunks) still yields a reward instead of leaving it unbound.
            z_pred = self.wm.predict(z[:, -1:])
            z_new = z_pred[:, -1:, ...]
            z_obs, _ = self.wm.separate_emb(z_new)
            rew, _ = self.compute_nn_reward(z_obs['visual'])
        return rew
    
    # deprecated
    def compute_current_reward(self, current_obs):
        """Return the nearest-demo reward of the current (not predicted) observation.

        Deprecated helper. Encodes ``current_obs`` (broadcast to batch size 1)
        with the world-model encoder, drops the time axis of the visual latent
        and scores it with ``compute_nn_reward``.

        Args:
            current_obs: unnormalized observation dict accepted by ``prepare_obs``.

        Returns:
            The reward tensor returned by ``compute_nn_reward``.
        """
        with torch.no_grad():
            wm_input = self.prepare_obs(current_obs, 1)
            visual_latent = self.wm.encode_obs(wm_input)['visual']
            return self.compute_nn_reward(visual_latent.squeeze(1))[0]

    def compute_nn_reward(self, current_visual_latent, chunk_size=2048):
        """Reward each latent by minus its L2 distance to the nearest demo latent.

        Args:
            current_visual_latent: visual latents with batch on dim 0. Inputs with
                more than two dims are flattened to ``(B, -1)`` and sliced the same
                way ``get_demo_latents`` sliced the demo latents (ToolHang:
                ``[..., 512:]``, Transport: ``[..., :1024]``). For Nut, both query
                and demo latents are weighted (first 512 features x2, next 512
                x0.5) before the distance is taken.
            chunk_size: number of demo latents moved to the query's device and
                compared at once; lower it to reduce peak memory.

        Returns:
            ``(reward, idx)``, both of shape ``(B,)``: ``reward`` is minus the
            distance to the nearest demo latent and ``idx`` is that latent's row
            in ``self.demo_visual_latents`` (earliest row on ties).

        Raises:
            ValueError: if ``chunk_size < 1`` or there are no demo latents.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        if len(current_visual_latent.shape) > 2:
            current_visual_latent = current_visual_latent.reshape(current_visual_latent.size(0), -1)
            if 'ToolHang' in self.env_name:
                current_visual_latent = current_visual_latent[..., 512:]
            elif 'Transport' in self.env_name:
                current_visual_latent = current_visual_latent[..., :1024]

        device = current_visual_latent.device
        weights = None
        if 'Nut' in self.env_name:
            # Built for every Nut query (not only >2-D ones) so query and demo
            # latents are always compared in the same weighted space.
            weights = torch.cat([
                torch.full((512,), 2.0, device=device),
                torch.full((512,), 0.5, device=device),
            ])
            current_visual_latent = current_visual_latent * weights

        num_demos = self.demo_visual_latents.shape[0]
        if num_demos == 0:
            raise ValueError("compute_nn_reward: no demo latents to compare against")

        best_cost = None
        best_idx = None
        # Demo latents stay on CPU; stream them to the device chunk by chunk.
        for start in range(0, num_demos, chunk_size):
            demo_chunk = self.demo_visual_latents[start:start + chunk_size].to(device, non_blocking=True)
            if weights is not None:
                demo_chunk = demo_chunk * weights
            dist = torch.cdist(current_visual_latent, demo_chunk, p=2)  # (B, chunk)
            cost, idx = dist.min(dim=-1)  # (B,), (B,)
            idx = idx + start  # chunk-local -> global row index
            if best_cost is None:
                best_cost, best_idx = cost, idx
            else:
                # Strict `<` keeps the earliest index on ties; torch.where is
                # out-of-place, so gradients flow cleanly (cond_fn differentiates this).
                improved = cost < best_cost
                best_cost = torch.where(improved, cost, best_cost)
                best_idx = torch.where(improved, idx, best_idx)

        return -best_cost, best_idx

    def plan_shooting(self, current_obs, current_state, init_actions, eval_id=None, timestep=None):
        """Pick the best candidate action sequence by one-step world-model reward.

        Args:
            current_obs: unnormalized observation dict accepted by ``prepare_obs``.
            current_state: unused here (kept for interface compatibility).
            init_actions: unnormalized candidates ``(num_samples, T, action_dim)``
                with ``T >= horizon * frameskip``; only the first
                ``horizon * frameskip`` steps are scored.
            eval_id: unused here (only referenced by disabled debug code).
            timestep: unused here (only referenced by disabled debug code).

        Returns:
            ``(actions, best_index)``: the first ``exec_step`` actions of the best
            candidate, unnormalized, with a leading batch dim of 1; and that
            candidate's index.

        Raises:
            ValueError: if ``init_actions`` contains no candidates.
        """
        self.timestep += 1
        plan_len = self.horizon * self.frameskip
        init_actions = self.wm_normalizer['act'](init_actions[:, :plan_len])
        print('1 init_actions ', init_actions.shape)
        num_samples = init_actions.shape[0]
        if num_samples == 0:
            raise ValueError("plan_shooting: init_actions contains no candidates")

        # Score at most `bs` candidates per world-model call to bound GPU memory;
        # the final chunk may be smaller, so any number of candidates works.
        bs = 8
        value = torch.zeros((num_samples), device=self.device)
        for start in range(0, num_samples, bs):
            action_batch = init_actions[start:start + bs, :, :]
            value[start:start + bs] = self.compute_rollout_reward(current_obs, action_batch)
            torch.cuda.empty_cache()

        best_value, best_index = value.max(0)
        print('best_value ', best_value.item())
        best_action = init_actions[best_index]
        print('best_action ', best_action.shape)

        best_action_to_execute = best_action[:self.exec_step]
        best_action_to_execute_unnormalized = self.wm_normalizer['act'].unnormalize(best_action_to_execute)
        return best_action_to_execute_unnormalized.unsqueeze(0), best_index

    def clean_grad(self):
        """Zero, in place, every existing gradient on the world model's parameters.

        Parameters whose ``.grad`` is ``None`` are left as ``None``.
        """
        grads = (p.grad for p in self.wm.parameters())
        for grad in grads:
            if grad is None:
                continue
            grad.zero_()

    def cond_fn(self, sample, current_obs):
        """
        Computes the classifier guidance gradient for one-step rollout.

        The sampled actions are mapped from the policy's normalized space to the
        world model's normalized space, the first action chunk is encoded with
        the current observation, one step is predicted, and the cost is minus
        the nearest-demo reward of the predicted visual latent.

        Args:
            sample (torch.Tensor): normalized noisy action sequence
                ``(B, T, action_dim)``; steps ``1 .. horizon * frameskip`` are used.
            current_obs: unnormalized observation dict accepted by ``prepare_obs``.

        Returns:
            torch.Tensor: gradient of the mean cost with respect to ``sample``
            (same shape as ``sample``).
        """
        self.clean_grad()
        # Fresh leaf so the gradient is taken w.r.t. this sample only.
        sample = sample.detach().requires_grad_(True)
        plan_len = self.horizon * self.frameskip
        init_actions_normalized = sample[:, 1:1 + plan_len]
        init_actions_unnormalized = self.policy_action_normalizer.unnormalize(init_actions_normalized)
        init_actions = self.wm_normalizer['act'].normalize(init_actions_unnormalized)
        action_batch = rearrange(init_actions, 'b (h f) a -> b h (f a)', f=self.frameskip, h=self.horizon)
        batch_size = init_actions.shape[0]
        current_obs_wm = self.prepare_obs(current_obs, batch_size)

        act_0 = action_batch[:, :1, :]  # use the first action chunk
        z = self.wm.encode(current_obs_wm, act_0)

        # Exactly one prediction step. This keeps the cost connected to `sample`
        # even when horizon == 1 (no trailing chunks); otherwise the cost would
        # be a constant and autograd.grad would fail.
        z_pred = self.wm.predict(z)
        z_new = z_pred[:, -1:, ...]
        z_obs, _ = self.wm.separate_emb(z_new)
        rew, _ = self.compute_nn_reward(z_obs['visual'])
        # Log the batch mean: .item() on the raw reward fails for batch_size > 1.
        print('future cost 1', -1 * rew.mean().item())

        # float32 accumulator dtype, then mean cost over the batch.
        cost = -1 * rew.float()
        loss = cost.mean()

        grad = torch.autograd.grad(loss, sample, retain_graph=True)[0]
        print("Gradient norm:", grad.norm().item())

        self.idx += 1
        return grad
    
    def compute_loss(self, sample, current_obs):
        """Return the mean one-step nearest-demo cost of a policy action sample.

        The sampled actions are mapped from the policy's normalized space to the
        world model's normalized space, the first action chunk is encoded with
        the current observation, one step is predicted, and the cost is minus
        the nearest-demo reward of the predicted visual latent. ``sample`` is not
        detached, so the loss stays differentiable w.r.t. the caller's graph.

        Args:
            sample: normalized (policy-space) action sequence ``(B, T, action_dim)``;
                steps ``1 .. horizon * frameskip`` are used.
            current_obs: unnormalized observation dict accepted by ``prepare_obs``.

        Returns:
            Scalar tensor: the batch-mean cost.
        """
        plan_len = self.horizon * self.frameskip
        init_actions_normalized = sample[:, 1:1 + plan_len]
        init_actions_unnormalized = self.policy_action_normalizer.unnormalize(init_actions_normalized)
        init_actions = self.wm_normalizer['act'].normalize(init_actions_unnormalized)
        action_batch = rearrange(init_actions, 'b (h f) a -> b h (f a)', f=self.frameskip, h=self.horizon)
        batch_size = init_actions.shape[0]
        current_obs_wm = self.prepare_obs(current_obs, batch_size)

        act_0 = action_batch[:, :1, :]  # use the first action chunk
        z = self.wm.encode(current_obs_wm, act_0)

        # Single predicted step.
        z_pred = self.wm.predict(z)
        z_new = z_pred[:, -1:, ...]
        z_obs, _ = self.wm.separate_emb(z_new)
        rew, _ = self.compute_nn_reward(z_obs['visual'])
        # Log the batch mean: .item() on the raw reward fails for batch_size > 1.
        print('future cost 1', -1 * rew.mean().item())

        # float32 accumulator dtype, then mean cost over the batch.
        cost = -1 * rew.float()
        loss = cost.mean()
        self.idx += 1
        return loss


    def optimize_gd(self, current_action, current_obs, frame_assembled_success=False,
                    cost_threshold=1.4, num_iters=20, lr=1e-4):
        """Refine one action chunk by gradient descent on the one-step nearest-demo cost.

        If the current observation's nearest-demo cost is already below
        ``cost_threshold``, ``current_action`` is returned unchanged.

        Args:
            current_action: unnormalized actions ``(B, frameskip, action_dim)``.
            current_obs: unnormalized observation dict accepted by ``prepare_obs``.
            frame_assembled_success: unused.
            cost_threshold: skip optimization when the current cost is below this.
            num_iters: number of Adam steps.
            lr: Adam learning rate.

        Returns:
            Unnormalized actions with the shape of ``current_action``, detached
            from the autograd graph (or ``current_action`` itself if skipped).
        """
        current_cost = -1 * self.compute_current_reward(current_obs)[0]
        print('current_cost ', current_cost.item())
        if current_cost.item() < cost_threshold:
            return current_action

        init_actions = self.wm_normalizer['act'](current_action)
        init_actions = rearrange(init_actions, 'b (h f) a -> b h (f a)', f=self.frameskip, h=1)
        init_actions = init_actions.clone().detach().requires_grad_(True)
        batch_size = init_actions.shape[0]

        optimizer = optim.Adam([init_actions], lr=lr)
        current_obs_wm = self.prepare_obs(current_obs, batch_size)

        for optim_iter in range(num_iters):
            optimizer.zero_grad()
            z = self.wm.encode(current_obs_wm, init_actions)
            z_pred = self.wm.predict(z[:, -1:])
            z_new = z_pred[:, -1:, ...]
            z_obs, _ = self.wm.separate_emb(z_new)
            rew, _ = self.compute_nn_reward(z_obs['visual'])
            loss = (-1 * rew).mean()
            # Differentiate w.r.t. the actions only: loss.backward() would also
            # accumulate never-cleared gradients into every world-model parameter.
            init_actions.grad = torch.autograd.grad(loss, init_actions)[0]
            optimizer.step()
            print('optim_iter ', optim_iter, 'loss ', loss.item())

        # Return a plain tensor, not one attached to the optimization graph.
        with torch.no_grad():
            init_actions = rearrange(init_actions, 'b h (f a) -> b (h f) a', f=self.frameskip, h=1)
            init_actions = self.wm_normalizer['act'].unnormalize(init_actions)
        return init_actions
    
    def prepare_obs(self, current_obs, action_shape):
        """Convert a raw observation dict into world-model input, broadcast over a batch.

        Args:
            current_obs: dict of unnormalized tensors for one observation:
                ``robot0_eef_pos``, ``robot0_eef_quat``, ``robot0_gripper_qpos``
                (plus the ``robot1_*`` equivalents for Transport), concatenated on
                the last dim into a ``(1, 1, D)`` proprio vector, and one image
                tensor per name in ``self.view_names``.
            action_shape: batch size to broadcast to (one copy per candidate),
                via ``expand`` (a view, not a copy).

        Returns:
            ``{'visual': {view: (action_shape, 1, 3, S, S)}, 'proprio':
            (action_shape, 1, D)}`` on ``self.device``. Proprio is normalized with
            ``wm_normalizer['state']``. ``S`` is ``transformed_img_size`` for the
            ResNet encoder and ``original_img_size`` otherwise.
        """
        proprio_keys = ['robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']
        if 'Transport' in self.env_name:
            proprio_keys += ['robot1_eef_pos', 'robot1_eef_quat', 'robot1_gripper_qpos']
        # Move every part before concatenating: torch.cat rejects mixed devices,
        # which previously broke Transport when robot1_* tensors were on CPU.
        proprio = torch.cat([current_obs[key].to(self.device) for key in proprio_keys], dim=-1)

        in_size = self.original_img_size
        out_size = self.transformed_img_size if self.use_resnet_encoder else self.original_img_size
        visual = {}
        for view_name in self.view_names:
            img = current_obs[view_name].to(self.device)
            # Only the ResNet-encoder path normalizes images before the transform.
            if self.use_resnet_encoder:
                img = self.wm_normalizer[view_name].normalize(img)
            img = self.img_transform(img.view(-1, 3, in_size, in_size))
            visual[view_name] = img.view(-1, 1, 3, out_size, out_size)

        proprio = self.wm_normalizer['state'].normalize(proprio)
        return {
            'visual': {key: value.expand(action_shape, -1, -1, -1, -1) for key, value in visual.items()},
            'proprio': proprio.expand(action_shape, -1, -1),
        }
    
    # deprecated
    def plan(self, current_obs, current_state, init_actions=None):
        """Deprecated MPPI-style planning loop that executes actions until near a demo.

        Repeats until the current nearest-demo reward exceeds -0.05:

        1. Run 15 iterations that sample actions around ``(mu, sigma)`` (clamped
           to the training action range) and score them with
           ``compute_rollout_reward``.
        2. Refit ``(mu, sigma)`` to the top ``num_elites`` using exponential
           weights.
        3. Draw one elite by weight and execute it via ``eval_actions``.
        4. Continue from the resulting observation and state.

        NOTE(review): relies on ``self.num_samples``, ``self.batch_size`` and
        ``self.num_elites``, which ``__init__`` does not set.

        Args:
            current_obs: unnormalized observation dict accepted by ``prepare_obs``.
            current_state: state forwarded to ``eval_actions``.
            init_actions: optional unnormalized warm start ``(1, T, action_dim)``.
        """
        while True:
            current_reward = self.compute_current_reward(current_obs)
            print('current_reward ', current_reward.item())
            if current_reward.item() > - 0.05:
                break
            mu, sigma = self.init_mu_sigma(init_actions)
            mu, sigma = mu.to(self.device), sigma.to(self.device) # (num_samples, horizon * frameskip, action_dim)
            for i in range(15):
                actions = torch.clamp(mu + sigma * \
                    torch.randn(self.num_samples, self.horizon * self.frameskip, self.action_dim, device=sigma.device),
                    min=self.training_action_min.view(1, 1, self.action_dim).to(sigma.device),
                    max=self.training_action_max.view(1, 1, self.action_dim).to(sigma.device))
                value = torch.zeros((self.num_samples), device=sigma.device)

                for batch in range(self.num_samples // self.batch_size):
                    action_batch = actions[batch*self.batch_size:(batch+1)*self.batch_size, :, :]
                    reward = self.compute_rollout_reward(
                        current_obs, action_batch
                    )
                    value[batch*self.batch_size:(batch+1)*self.batch_size] = reward

                elite_idxs = torch.topk(value, self.num_elites, dim=0).indices
                elite_value, elite_actions = value[elite_idxs], actions[elite_idxs, :, :]
                # MPPI: exponential weights relative to the best elite (the max
                # is subtracted for numerical stability), normalized to sum to 1.
                max_value = elite_value.max(0)[0]
                score = torch.exp(0.5*(elite_value - max_value))
                score /= score.sum(0)
                _mu = torch.sum(score.view(self.num_elites, 1, 1) * elite_actions, dim=0) / (score.sum(0) + 1e-9)
                _sigma = torch.sqrt(torch.sum(score.view(self.num_elites, 1, 1) * (elite_actions - _mu.unsqueeze(0)) ** 2, dim=0) / (score.sum(0) + 1e-9))
                mu, sigma = _mu, _sigma

                torch.cuda.empty_cache()
            # Sample one elite in proportion to its final MPPI weight.
            score = score.cpu().numpy()
            elite_action = elite_actions[np.random.choice(np.arange(score.shape[0]), p=score), :, :]

            # eval_actions requires eval_id; omitting it raised a TypeError.
            next_obs, next_state = self.eval_actions(current_obs, current_state, elite_action, timestep=self.timestep, eval_id=None)
            init_actions = None
            current_obs = next_obs
            current_state = next_state
            self.timestep += 1

    def eval_actions(self, current_obs, current_state, elite_action, timestep, eval_id):
        """Visualize a planned action sequence and execute its first steps in the env.

        The sequence is rolled out through the world model ``self.wm`` from
        ``current_obs``. Every predicted visual latent is scored with
        ``compute_nn_reward`` and decoded. The first ``self.exec_step`` actions
        are then executed in ``self.eval_env``, and the returned observation and
        state come from that part. The remaining actions are stepped afterwards
        only to collect ground-truth frames for comparison. For each view, one
        image grid is saved to ``self.image_output_dir`` containing the decoded
        predictions (annotated with their rewards), the nearest-neighbour demo
        frames and the ground-truth frames.

        Args:
            current_obs: unnormalized observation, documented as (1, h=1, 3, 224, 224).
            current_state: simulator state the env is reset to before execution.
            elite_action: normalized actions, (horizon * frameskip, action_dim).
            timestep: planning step index; at 0 the env reset also passes
                ``self.model_file``. Also used in the output filename.
            eval_id: episode identifier used in the output filename.

        Returns:
            ``(next_obs, next_state)``: the ``make_torch_obs`` dict and the
            flattened sim state after the last executed action.

        Raises:
            ValueError: if no action would be executed (``self.exec_step`` < 1),
                since there would be no next observation to return.
        """
        nn_idx_array = []
        reward_array = []
        visual_latents = []

        def record(z_obs):
            # Score a predicted latent against the demos and keep it for decoding.
            rew, idx = self.compute_nn_reward(z_obs['visual'])
            nn_idx_array.append(idx.item())
            reward_array.append(rew.item())
            visual_latents.append(z_obs)

        with torch.no_grad():
            elite_action = elite_action.unsqueeze(0)  # add batch axis -> (1, h*f, a)
            current_obs_wm = self.prepare_obs(current_obs, elite_action.shape[0])
            # Pack each group of `frameskip` consecutive actions into one world-model action.
            action_batch = rearrange(elite_action, 'b (h f) a -> b h (f a)', f=self.frameskip, h=self.horizon)
            act_0 = action_batch[:, :1]
            future_actions = action_batch[:, 1:]

            z = self.wm.encode(current_obs_wm, act_0)
            z_obs, _ = self.wm.separate_emb(z)
            record(z_obs)

            # Autoregressive rollout: predict the next latent, then attach the
            # planned action to it (replace_actions_from_z) before feeding it back.
            for t in range(future_actions.shape[1]):
                z_new = self.wm.predict(z[:, -1:])[:, -1:, ...]
                z_obs, _ = self.wm.separate_emb(z_new)
                record(z_obs)
                z_new = self.wm.replace_actions_from_z(z_new, future_actions[:, t:t + 1, :])
                z = torch.cat([z, z_new], dim=1)

            # One more prediction for the outcome of the last planned action.
            z_new = self.wm.predict(z[:, -1:])[:, -1:, ...]
            z_obs, _ = self.wm.separate_emb(z_new)
            record(z_obs)

            # Decode inside no_grad so no autograd graph is built through the decoder.
            stacked_latents = {key: torch.cat([v[key] for v in visual_latents], dim=0) for key in visual_latents[0]}
            decoded_img = self.wm.decode_obs(stacked_latents)[0]['visual']
        decoded_img = {k: v.cpu() for k, v in decoded_img.items()}

        nn_img = {key: self.demo_images[key][nn_idx_array] for key in self.demo_images.keys()}

        # The first ground-truth frame per view is the current observation itself.
        gt_img = {view_name: [current_obs_wm['visual'][view_name]] for view_name in self.view_names}

        flat_actions = rearrange(action_batch.squeeze(0), 'h (f d) -> (h f) d', h=self.horizon, f=self.frameskip)
        action_exec = self._unnormalize_env_actions(flat_actions[:self.exec_step, :])
        if action_exec.shape[0] == 0:
            raise ValueError(f'exec_step must be >= 1 to execute any action, got {self.exec_step}')

        reset_args = dict(states=current_state)
        if timestep == 0:
            reset_args["model"] = self.model_file
        self.eval_env.reset_to(reset_args)

        for i, act in enumerate(action_exec):
            next_obs, _, _, _ = self.eval_env.step(act)
            next_obs = self.make_torch_obs(next_obs)
            # Sample one ground-truth frame per world-model step.
            if (i + 1) % self.frameskip == 0:
                for view_name in self.view_names:
                    gt_img[view_name].append(self._transform_env_view(next_obs, view_name))
        next_state = self.eval_env.env.sim.get_state().flatten()

        # Step through the rest of the plan only to collect ground-truth frames
        # aligned with the predictions; the env state reached here is not returned.
        remaining_action = self._unnormalize_env_actions(flat_actions[self.exec_step:, :])
        last_idx = remaining_action.shape[0] - 1
        for i, act in enumerate(remaining_action):
            dummy_obs, _, _, _ = self.eval_env.step(act)
            if (i + 1) % self.frameskip == 0 or i == last_idx:
                dummy_obs = self.make_torch_obs(dummy_obs)
                for view_name in self.view_names:
                    gt_img[view_name].append(self._transform_env_view(dummy_obs, view_name))

        gt_img = {key: torch.cat(frames, dim=0).cpu() for key, frames in gt_img.items()}

        annotated_decoded_imgs = {
            key: torch.stack([add_text_to_image(img[0], reward) for img, reward in zip(frames, reward_array)])
            for key, frames in decoded_img.items()
        }

        nrow = gt_img[self.view_names[0]].shape[0]
        for key, gt in gt_img.items():
            # Concatenated in order: decoded predictions, nearest demo frames, ground truth.
            combined = torch.cat([annotated_decoded_imgs[key], nn_img[key].squeeze(1), gt.squeeze(1)], dim=0)
            out_path = f'{self.image_output_dir}/epi_{eval_id}_timestep_{timestep}_{key}.png'
            utils.save_image(combined,
                             out_path,
                             nrow=nrow,
                             normalize=True,
                             value_range=(-1, 1),)
            print('saved to ', out_path)

        return next_obs, next_state

    def _unnormalize_env_actions(self, actions):
        """Map normalized world-model actions to a numpy array the env can step with.

        Undoes ``self.wm_normalizer['act']``. For 10-D actions it also undoes the
        rotation encoding via ``undo_transform_action``.
        """
        actions = self.wm_normalizer['act'].unnormalize(actions).cpu().detach().numpy()
        if actions.shape[-1] == 10:
            actions = undo_transform_action(actions, self.rotation_transformer)
        return actions

    def _transform_env_view(self, torch_obs, view_name):
        """Apply ``self.img_transform`` to one camera view of a ``make_torch_obs`` dict.

        Returns a tensor shaped (bs, h, 3, S, S) with S = ``self.original_img_size``.
        """
        frames = torch_obs[view_name]
        bs, h = frames.shape[:2]
        size = self.original_img_size
        flat = frames.to('cuda').view(-1, 3, size, size)
        return self.img_transform(flat).view(bs, h, 3, size, size)

    def make_torch_obs(self, obs):
        """Copy the relevant entries of a raw env observation into batched tensors.

        The selected entries are the end-effector position/quaternion, the gripper
        joint positions and every camera listed in ``self.view_names``. Each entry
        is copied with ``torch.tensor`` and gets two leading singleton axes, so an
        array of shape S becomes a tensor of shape (1, 1, *S). All other keys of
        ``obs`` are ignored.
        """
        selected_keys = ['robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos', *self.view_names]
        return {name: torch.tensor(obs[name])[None, None] for name in selected_keys}


