"""
This script demonstrates how to load and rollout a finetuned Octo model.
We use the Octo model finetuned on ALOHA sim data from the examples/finetune_new_observation_action.py script.

For installing the ALOHA sim environment, clone: https://github.com/tonyzhaozh/act
Then run:
pip3 install opencv-python modern_robotics pyrealsense2 h5py_cache pyquaternion pyyaml rospkg pexpect mujoco==2.3.3 dm_control==1.0.9 einops packaging h5py

Finally, modify the `sys.path.append` statement below to add the ACT repo to your path.
If you are running this on a head-less server, start a virtual display:
    Xvfb :1 -screen 0 1024x768x16 &
    export DISPLAY=:1

To run this script, run:
    cd examples
    python3 03_eval_finetuned.py --ckpt-path=<path_to_finetuned_aloha_checkpoint>
"""
import sys

from absl import app, flags, logging
import gym
import jax
import numpy as np
import wandb


from octo.model.octo_model import OctoModel



############### MIKASA IMPORTS! ####################

from collections import defaultdict
import json
import os
import signal
import time
import numpy as np
from typing import Annotated, Optional
import gymnasium as gym
import numpy as np
import tyro
from dataclasses import dataclass
from pathlib import Path

import cv2
import mikasa_robo_suite
from mikasa_robo_suite.utils.wrappers import StateOnlyTensorToDictWrapper
from tqdm.notebook import tqdm
# import torch
# import gymnasium as gym

from mikasa_robo_suite.memory_envs import *
from mikasa_robo_suite.utils.wrappers import *


import copy
from typing import Dict
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore', message='.*env\\.\\w+ to get variables from other wrappers is deprecated.*')


from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
from colorama import Fore, Style

from gymnasium import spaces
from collections import deque
import logging
from typing import Dict, Optional, Sequence, Tuple
import jax
import numpy as np
import tensorflow as tf

import torch
import os
import jax.numpy as jnp


@dataclass
class Args:
    """
    Command-line options for evaluating a finetuned Octo checkpoint on a MIKASA-Robo task.

    The options are parsed with `tyro.cli(Args)`. Example command:

        XLA_PYTHON_CLIENT_PREALLOCATE=false python <this_script>.py --env-id RememberColor3-v0 --ckpt-path <checkpoint_dir> -s 0

    NOTE(review): a number of the options below appear to be carried over from
    other training/evaluation scripts and are not read by the evaluation code
    visible in this file.
    """

    exp_name: Optional[str] = None
    env_id: str = "RememberColor3-v0"
    """The environment ID of the task you want to simulate, e.g. RememberColor3-v0 or ShellGamePush-v0"""

    shader: str = "default"

    num_envs: int = 1
    """Number of environments to run. With more than 1 environment the environment will use the GPU backend 
    which runs faster enabling faster large-scale evaluations. Note that the overall behavior of the simulation
    will be slightly different between CPU and GPU backends."""

    num_episodes: int = 100
    """Number of episodes to run and record evaluation metrics over"""

    record_dir: str = "videos"
    """The directory to save videos and results"""

    model: Optional[str] = 'octo-base'
    """The model to evaluate on the given environment. Can be one of octo-base, octo-small, rt-1x. If not given, random actions are sampled."""

    # NOTE: no type annotation, so this is a plain class attribute rather than
    # a dataclass field; it is not part of __init__ and not a CLI option.
    policy_setup = "google_robot"

    ckpt_path: str = "anon2SimplerEnv/octo/examples/mikasa_octo_finetuned_mikasa_remcol3"
    """Path of the finetuned Octo checkpoint to load"""

    seed: Annotated[int, tyro.conf.arg(aliases=["-s"])] = 0
    """Seed the model and environment. Default seed is 0"""

    reset_by_episode_id: bool = True
    """Whether to reset by fixed episode ids instead of random sampling initial states."""

    info_on_video: bool = False
    """Whether to write info text onto the video"""

    save_video: bool = True
    """Whether to save videos"""

    debug: bool = False

    device: str = 'cuda:0'

    # These two fields used to be declared twice (first with a None default,
    # later with 128). The later declaration won, so 128 is the effective
    # default; they are declared once here, at their original field position.
    camera_width: Optional[int] = 128
    """the width of the camera image. If none it will use the default the environment specifies"""
    camera_height: Optional[int] = 128
    """the height of the camera image. If none it will use the default the environment specifies."""

    # to be filled in runtime
    grad_steps_per_iteration: int = 0
    """the number of gradient updates per iteration"""
    steps_per_env: int = 0
    """the number of steps each parallel env takes per iteration"""

    include_oracle: bool = False
    """if toggled, oracle info (such as cup_with_ball_number in ShellGamePush-v0) will be used during the training, i.e. reducing memory task to MDP"""
    noop_steps: int = 1
    """if = 1, then no noops, if > 1, then noops for t ~ [0, noop_steps-1]"""
    include_rgb: bool = True
    """if toggled, rgb images will be included in the observation space"""
    include_joints: bool = True
    """[works only with include_rgb=True] if toggled, joints will be included in the observation space"""
    reward_mode: str = 'normalized_dense'  # sparse | normalized_dense
    """the mode of the reward function"""

    control_mode: Optional[str] = "pd_joint_delta_pos"
    """the control mode to use for the environment"""
    render_mode: str = "all"
    """the environment rendering mode"""

    include_state: bool = False
    """whether to include state information in observations"""
    total_timesteps: int = 50_000_000
    """total timesteps of the experiments"""
    num_eval_envs: int = 16
    """the number of parallel evaluation environments"""
    partial_reset: bool = False
    """whether to let parallel environments reset upon termination instead of truncation"""
    eval_partial_reset: bool = False
    """whether to let parallel evaluation environments reset upon termination instead of truncation"""
    num_steps: int = 270
    """the number of steps to run in each environment per policy rollout"""
    num_eval_steps: int = 270
    """the number of steps to run in each evaluation environment during evaluation"""

    # NOTE(review): not read by the code visible in this file.
    episode_timeout: int = 200
    reconfiguration_freq: Optional[int] = None
    """how often to reconfigure the environment during training"""
    eval_reconfiguration_freq: Optional[int] = 1

    capture_video: bool = True
    evaluate: bool = True
    sim_backend: str = 'gpu'

    # NOTE(review): the evaluation loop visible in this file reads
    # `language_instruction`, not `instruction`.
    instruction: str = 'pick the green cube'

    language_instruction: str = 'CHANGE ME'
    """The language instruction passed to the model when creating the task; the default is a placeholder"""

def get_image_from_mikasa_obs_dict(env, obs, camera_name='base_camera'):
    """Return the RGB image of one camera from a nested observation dictionary.

    Args:
        env: Unused; kept so existing call sites keep working.
        obs: Mapping indexed as ``obs["sensor_data"][camera_name]["rgb"]``.
            The stored value must provide a torch-style ``.to(dtype)`` method.
        camera_name: Key of the camera to read. Defaults to ``'base_camera'``.

    Returns:
        The selected image converted to ``torch.uint8``.

    Raises:
        KeyError: If ``camera_name`` (or the surrounding keys) is missing.
    """
    import torch

    # The requested camera is honoured; previously the argument was
    # overwritten with a hard-coded 'base_camera' and therefore ignored.
    img = obs["sensor_data"][camera_name]["rgb"]
    return img.to(torch.uint8)

class CameraWrapper(gym.ObservationWrapper):
    """
    Reduces each observation to the two RGB camera images, exposed under the
    keys `image_primary` (base camera) and `image_wrist` (hand camera).

    NOTE(review): the observation space is looked up under
    `observation_space['state']['sensor_data']`, whereas `observation` indexes
    `obs['sensor_data']` directly - confirm both match the wrapped env.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)

        camera_spaces = env.observation_space.spaces['state'].spaces['sensor_data'].spaces
        primary_rgb_space = camera_spaces['base_camera'].spaces['rgb']
        wrist_rgb_space = camera_spaces['hand_camera'].spaces['rgb']

        # Override the observation space so it only advertises the two images.
        self.observation_space = spaces.Dict(
            {
                'image_primary': primary_rgb_space,
                'image_wrist': wrist_rgb_space,
            }
        )

    def observation(self, obs: dict) -> dict:
        # Build a fresh dictionary holding just the camera images.
        sensor_data = obs['sensor_data']
        return {
            'image_primary': sensor_data['base_camera']['rgb'],
            'image_wrist': sensor_data['hand_camera']['rgb'],
        }

def stack_and_pad(history: deque, num_obs: int):
    """
    Collapses a history of observation dictionaries into one dictionary whose
    values are stacked along a new leading axis (one entry per history slot).

    A float32 `pad_mask` entry of length `len(history)` is added: the last
    `min(num_obs, len(history))` slots are 1 (real observations) and the
    leading slots are 0 (padding).
    """

    def _to_numpy(value):
        # torch tensors (possibly living on the GPU) are detached and copied
        # to host memory; anything else is passed to np.stack untouched.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    horizon = len(history)
    full_obs = {
        key: np.stack([_to_numpy(entry[key]) for entry in history])
        for key in history[0]
    }

    num_valid = min(num_obs, horizon)
    timestep_pad_mask = np.zeros(horizon, dtype=np.float32)
    timestep_pad_mask[horizon - num_valid:] = 1
    full_obs["pad_mask"] = timestep_pad_mask

    return full_obs

def space_stack(space: gym.Space, repeat: int):
    """
    Builds the Gym space obtained by repeating `space` `repeat` times.

    Box bounds gain a new leading axis of length `repeat`, Discrete spaces
    become a MultiDiscrete with `repeat` entries, and Dict spaces are handled
    recursively per key. Any other space type raises ValueError.
    """
    if isinstance(space, gym.spaces.Box):
        tiled_low = np.repeat(space.low[None], repeat, axis=0)
        tiled_high = np.repeat(space.high[None], repeat, axis=0)
        return gym.spaces.Box(low=tiled_low, high=tiled_high, dtype=space.dtype)

    if isinstance(space, gym.spaces.Discrete):
        return gym.spaces.MultiDiscrete([space.n] * repeat)

    if isinstance(space, gym.spaces.Dict):
        stacked = {}
        for name, subspace in space.spaces.items():
            stacked[name] = space_stack(subspace, repeat)
        return gym.spaces.Dict(stacked)

    raise ValueError(f"Space {space} is not supported by Octo Gym wrappers.")


def listdict2dictlist(LD):
    """
    Turns a list of dictionaries into a dictionary of lists.

    The keys (and their order) are taken from the first dictionary; every
    dictionary in `LD` must contain each of them.
    """
    merged = {}
    for key in LD[0]:
        merged[key] = [entry[key] for entry in LD]
    return merged

class HistoryWrapper(gym.Wrapper):
    """
    Keeps the last `horizon` observations and returns them stacked.

    On reset the history is filled with `horizon` copies of the first
    observation; each step then pushes the newest observation and drops the
    oldest. The returned dictionary carries a `pad_mask` entry (see
    `stack_and_pad`) marking which history slots are padding.
    """

    def __init__(self, env: gym.Env, horizon: int):
        super().__init__(env)
        self.horizon = horizon
        self.num_obs = 0
        # Bounded deque: appending beyond `horizon` evicts the oldest entry.
        self.history = deque(maxlen=self.horizon)

        self.observation_space = space_stack(self.env.observation_space, self.horizon)

    def step(self, action):
        obs, reward, done, trunc, info = self.env.step(action)
        self.history.append(obs)
        self.num_obs += 1
        # The history is only full once reset() has been called.
        assert len(self.history) == self.horizon

        return stack_and_pad(self.history, self.num_obs), reward, done, trunc, info

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.num_obs = 1
        # Fill every slot with the initial observation; with num_obs == 1 all
        # but the last slot are flagged as padding.
        for _ in range(self.horizon):
            self.history.append(obs)

        return stack_and_pad(self.history, self.num_obs), info


class RHCWrapper(gym.Wrapper):
    """
    Receding horizon control: the policy supplies a chunk of actions and the
    first `exec_horizon` of them are executed, stopping early when the wrapped
    environment reports `done` or `trunc`.

    `step` returns the last observation, the mean of the collected rewards,
    the last `done`/`trunc` flags, and an info dictionary of per-step lists
    that also holds the raw `rewards` and `observations`.
    """

    def __init__(self, env: gym.Env, exec_horizon: int):
        super().__init__(env)
        self.exec_horizon = exec_horizon

    def step(self, actions):
        # A single flat action is promoted to a chunk of length one.
        if self.exec_horizon == 1 and len(actions.shape) == 1:
            actions = actions[None]
        assert len(actions) >= self.exec_horizon

        step_obs, step_rewards, step_infos = [], [], []
        for idx in range(self.exec_horizon):
            obs, reward, done, trunc, info = self.env.step(torch.Tensor(actions[idx]))
            step_obs.append(obs)
            step_rewards.append(reward)
            step_infos.append(info)

            if done or trunc:
                break

        # Merge the per-step info dictionaries into one dictionary of lists,
        # keyed by the entries of the first step's info.
        infos = {key: [entry[key] for entry in step_infos] for key in step_infos[0]}
        infos["rewards"] = step_rewards
        infos["observations"] = step_obs

        # Rewards are stacked with torch.stack, so they must be torch tensors.
        mean_reward = torch.stack(step_rewards).mean().item()
        return obs, mean_reward, done, trunc, infos


class TemporalEnsembleWrapper(gym.Wrapper):
    """
    Temporal ensembling as described in https://arxiv.org/abs/2304.13705.

    Every call to `step` receives a chunk of predicted actions. The action
    actually executed is an exponentially weighted average of what the last
    (up to) `pred_horizon` chunks predicted for the current timestep.
    """

    def __init__(self, env: gym.Env, pred_horizon: int, exp_weight: int = 0):
        super().__init__(env)
        self.pred_horizon = pred_horizon
        self.exp_weight = exp_weight

        # Oldest chunk first; chunks older than `pred_horizon` steps fall off.
        self.act_history = deque(maxlen=self.pred_horizon)

        self.action_space = space_stack(self.env.action_space, self.pred_horizon)

    def step(self, actions):
        assert len(actions) >= self.pred_horizon

        self.act_history.append(actions[: self.pred_horizon])
        n_chunks = len(self.act_history)

        # The chunk stored at position `pos` was predicted `n_chunks - 1 - pos`
        # steps ago, so its entry at that offset targets the current timestep.
        curr_act_preds = np.stack(
            [
                chunk[n_chunks - 1 - pos]
                for pos, chunk in enumerate(self.act_history)
            ]
        )

        # Position 0 (the oldest chunk) receives the largest weight; more
        # recent predictions get exponentially less weight.
        raw_weights = np.exp(-self.exp_weight * np.arange(n_chunks))
        weights = raw_weights / raw_weights.sum()
        # Weighted average over all predictions made for this timestep.
        action = np.sum(weights[:, None] * curr_act_preds, axis=0)

        return self.env.step(action)

    def reset(self, **kwargs):
        # Start every episode with an empty prediction history.
        self.act_history = deque(maxlen=self.pred_horizon)
        return self.env.reset(**kwargs)


class ResizeImageWrapper(gym.ObservationWrapper):
    """
    Resizes the image observations of a robot environment to the size the
    model expects.

    The operations are meant to mirror the model's data pipeline: a lanczos
    resize (as done when converting raw data to RLDS), followed - for keys
    listed in `augmented_keys` - by a central crop-and-resize with bilinear
    interpolation that stands in for the average of the random crop/resize
    augmentation applied during training.
    """

    def __init__(
        self,
        env: gym.Env,
        resize_size: Optional[Dict[str, Tuple]] = None,
        augmented_keys: Sequence[str] = ("image_primary",),
        avg_scale: float = 0.9,
        avg_ratio: float = 1.0,
    ):
        super().__init__(env)
        assert isinstance(
            self.observation_space, gym.spaces.Dict
        ), "Only Dict observation spaces are supported."
        obs_spaces = self.observation_space.spaces
        self.resize_size = resize_size
        self.augmented_keys = augmented_keys

        if len(self.augmented_keys) > 0:
            # Fractional height/width of the average crop, centred in the image.
            crop_height = tf.clip_by_value(tf.sqrt(avg_scale / avg_ratio), 0, 1)
            crop_width = tf.clip_by_value(tf.sqrt(avg_scale * avg_ratio), 0, 1)
            top = (1 - crop_height) / 2
            left = (1 - crop_width) / 2
            self.bounding_box = tf.stack(
                [top, left, top + crop_height, left + crop_width]
            )

        # Map "image_<name>" observation keys to their target size.
        self.keys_to_resize = {}
        if resize_size is not None:
            for name in resize_size.keys():
                self.keys_to_resize[f"image_{name}"] = resize_size[name]
        logging.info(f"Resizing images: {self.keys_to_resize}")

        # Advertise the resized uint8 RGB images in the observation space.
        for key, target_size in self.keys_to_resize.items():
            obs_spaces[key] = gym.spaces.Box(
                low=0,
                high=255,
                shape=target_size + (3,),
                dtype=np.uint8,
            )
        self.observation_space = gym.spaces.Dict(obs_spaces)

    def observation(self, observation):
        # The incoming dictionary is updated in place and returned.
        for key, target_size in self.keys_to_resize.items():
            resized = tf.image.resize(
                observation[key], size=target_size, method="lanczos3", antialias=True
            )

            # Keys that were augmented with random resizes and crops during
            # training get the average of that augmentation applied here.
            if key in self.augmented_keys:
                cropped = tf.image.crop_and_resize(
                    resized[None], self.bounding_box[None], [0], target_size
                )
                resized = cropped[0]

            rounded = tf.clip_by_value(tf.round(resized), 0, 255)
            observation[key] = tf.cast(rounded, tf.uint8).numpy()
        return observation


class NormalizeProprio(gym.ObservationWrapper):
    """
    Normalizes the `proprio` observation with dataset statistics.

    `action_proprio_metadata` is a nested dictionary whose list leaves are
    converted to numpy arrays. When it has a "proprio" entry, that entry must
    provide "mean" and "std" and may provide a boolean "mask" selecting which
    dimensions are normalized.
    """

    def __init__(
        self,
        env: gym.Env,
        action_proprio_metadata: dict,
    ):
        # `jax.tree_util.tree_map` is used instead of the top-level
        # `jax.tree_map` alias, which is deprecated and no longer available in
        # recent JAX releases.
        self.action_proprio_metadata = jax.tree_util.tree_map(
            np.array,
            action_proprio_metadata,
            is_leaf=lambda x: isinstance(x, list),
        )
        super().__init__(env)

    def normalize(self, data, metadata):
        """Return `(data - mean) / (std + 1e-8)` where the mask is set, `data` elsewhere."""
        # Without an explicit mask every dimension is normalized.
        mask = metadata.get("mask", np.ones_like(metadata["mean"], dtype=bool))
        return np.where(
            mask,
            (data - metadata["mean"]) / (metadata["std"] + 1e-8),
            data,
        )

    def observation(self, obs):
        """Normalize `obs["proprio"]` in place when proprio statistics are available."""
        if "proprio" in self.action_proprio_metadata:
            obs["proprio"] = self.normalize(
                obs["proprio"], self.action_proprio_metadata["proprio"]
            )
        else:
            assert "proprio" not in obs, "Cannot normalize proprio without metadata."
        return obs


# Parse the command-line options into an Args instance.
args = tyro.cli(Args)
# Seed NumPy's global RNG (Args.seed defaults to 0, so this normally runs).
if args.seed is not None:
    np.random.seed(args.seed)

# Timestamp used in run names and output directories further below.
TIME = time.strftime('%Y%m%d_%H%M%S')

# setup wandb for logging
wandb.init(name=f"eval_octo_{args.env_id}", project="Mikasa-Robo-VLA")

# load finetuned model
logging.info("Loading finetuned model...")
model = OctoModel.load_pretrained(args.ckpt_path)


# Select the wrapper stack and the oracle/prompt info keys for the requested
# environment. Wrapper classes are only looked up inside the builder of the
# family that matches, exactly as with one branch per family.


def _noop_entry():
    """Build the (InitialZeroActionWrapper, kwargs) entry shared by every family."""
    return (InitialZeroActionWrapper, {"n_initial_steps": args.noop_steps-1})


def _standard_stack(*info_wrappers):
    """Return [noop, *info_wrappers, RenderStepInfo, RenderRewardInfo, DebugReward] entries."""
    return [
        _noop_entry(),
        *[(wrapper_cls, {}) for wrapper_cls in info_wrappers],
        (RenderStepInfoWrapper, {}),
        (RenderRewardInfoWrapper, {}),
        (DebugRewardWrapper, {}),
    ]


# Each row: (environment ids, lazy wrapper-list builder, oracle_info, prompt_info)
_ENV_FAMILIES = (
    (
        ('ShellGamePush-v0', 'ShellGamePick-v0', 'ShellGameTouch-v0'),
        lambda: [
            _noop_entry(),
            (RenderStepInfoWrapper, {}),
            (ShellGameRenderCupInfoWrapper, {}),
            (RenderRewardInfoWrapper, {}),
            (DebugRewardWrapper, {}),
        ],
        'cup_with_ball_number',
        None,
    ),
    (
        (
            'InterceptSlow-v0', 'InterceptMedium-v0', 'InterceptFast-v0',
            'InterceptGrabSlow-v0', 'InterceptGrabMedium-v0', 'InterceptGrabFast-v0',
        ),
        _standard_stack,
        None,
        None,
    ),
    (
        (
            'RotateLenientPos-v0', 'RotateLenientPosNeg-v0',
            'RotateStrictPos-v0', 'RotateStrictPosNeg-v0',
        ),
        lambda: [
            _noop_entry(),
            (RenderStepInfoWrapper, {}),
            (RenderRewardInfoWrapper, {}),
            (RotateRenderAngleInfoWrapper, {}),
            (DebugRewardWrapper, {}),
        ],
        'angle_diff',
        'target_angle',
    ),
    (
        ('CameraShutdownPush-v0', 'CameraShutdownPick-v0'),
        lambda: [
            _noop_entry(),
            (CameraShutdownWrapper, {"n_initial_steps": 19}), # camera works only for t ~ [0, 19]
            (RenderStepInfoWrapper, {}),
            (RenderRewardInfoWrapper, {}),
        ],
        None,
        None,
    ),
    (
        ('TakeItBack-v0',),
        _standard_stack,
        None,
        None,
    ),
    (
        ('RememberColor3-v0', 'RememberColor5-v0', 'RememberColor9-v0'),
        lambda: _standard_stack(RememberColorInfoWrapper),
        None,
        None,
    ),
    (
        ('RememberShape3-v0', 'RememberShape5-v0', 'RememberShape9-v0'),
        lambda: _standard_stack(RememberShapeInfoWrapper),
        None,
        None,
    ),
    (
        ('RememberShapeAndColor3x2-v0', 'RememberShapeAndColor3x3-v0', 'RememberShapeAndColor5x3-v0'),
        lambda: _standard_stack(RememberShapeAndColorInfoWrapper),
        None,
        None,
    ),
    (
        ('BunchOfColors3-v0', 'BunchOfColors5-v0', 'BunchOfColors7-v0'),
        lambda: _standard_stack(MemoryCapacityInfoWrapper),
        None,
        None,
    ),
    (
        ('SeqOfColors3-v0', 'SeqOfColors5-v0', 'SeqOfColors7-v0'),
        lambda: _standard_stack(MemoryCapacityInfoWrapper),
        None,
        None,
    ),
    (
        ('ChainOfColors3-v0', 'ChainOfColors5-v0', 'ChainOfColors7-v0'),
        lambda: _standard_stack(MemoryCapacityInfoWrapper),
        None,
        None,
    ),
)

for _family_ids, _build_wrappers, oracle_info, prompt_info in _ENV_FAMILIES:
    if args.env_id in _family_ids:
        wrappers_list = _build_wrappers()
        break
else:
    raise ValueError(f"Unknown environment: {args.env_id}")






# Print a boxed summary of the selected environment, its wrapper stack and
# the observation-related flags.
print('\n' + '='*75)
print('║' + ' '*24 + 'Environment Configuration' + ' '*24 + '║')
print('='*75)
print('║' + f' Environment ID: {args.env_id}'.ljust(73) + '║')
print('║' + f' Oracle Info:    {oracle_info}'.ljust(73) + '║')
print('║ Wrappers:'.ljust(74) + '║')
# One line per wrapper, plus an extra line for non-empty keyword arguments.
for wrapper, kwargs in wrappers_list:
    print('║    ├─ ' + wrapper.__name__.ljust(65) + '║')
    if kwargs:
        print('║    │  └─ ' + str(kwargs).ljust(65) + '║')
print('║' + '-'*73 + '║')

# NOTE(review): the flag lines below are padded to 68 columns while the lines
# above are padded to 73, so the right border does not line up.
state_msg = 'state will be used' if args.include_state else 'state will not be used'
print('║' + f' include_state:       {str(args.include_state):<5} │ {state_msg}'.ljust(68) + '║')

rgb_msg = 'rgb images will be used' if args.include_rgb else 'rgb images will not be used'
print('║' + f' include_rgb:         {str(args.include_rgb):<5} │ {rgb_msg}'.ljust(68) + '║')

oracle_msg = 'oracle info will be used' if args.include_oracle else 'oracle info will not be used'
print('║' + f' include_oracle:      {str(args.include_oracle):<5} │ {oracle_msg}'.ljust(68) + '║')

joints_msg = 'joints will be used' if args.include_joints else 'joints will not be used'
print('║' + f' include_joints:      {str(args.include_joints):<5} │ {joints_msg}'.ljust(68) + '║')
print('='*75 + '\n')

# Sanity checks on the flag combination. These are plain asserts, so they are
# skipped when Python runs with -O.
assert any([args.include_state, args.include_rgb]), "At least one of include_state or include_rgb must be True."
assert not (args.include_joints and not args.include_rgb), "include_joints can only be True when include_rgb is True"

# Derive the observation-mode label from the four include_* flags.
# Combinations that are deliberately unsupported raise NotImplementedError;
# the assignments that used to follow those `raise` statements could never
# run and have been removed.
if args.include_state and not args.include_rgb and not args.include_oracle and not args.include_joints:
    MODE = 'state'
elif args.include_state and args.include_rgb and not args.include_oracle and not args.include_joints:
    raise NotImplementedError("state_rgb is not implemented and does not make sense, since any environment can be solved only by using state")
elif args.include_state and not args.include_rgb and args.include_oracle and not args.include_joints:
    raise NotImplementedError("state_oracle is not implemented and does not make sense, since the state already contains oracle information")
elif args.include_state and args.include_rgb and args.include_oracle and not args.include_joints:
    raise NotImplementedError("state_rgb_oracle is not implemented and does not make sense, since any environment can be solved only by using state")
elif not args.include_state and args.include_rgb and not args.include_oracle and not args.include_joints:
    MODE = 'rgb'
elif not args.include_state and args.include_rgb and args.include_oracle and not args.include_joints:
    MODE = 'rgb_oracle'
elif not args.include_state and args.include_rgb and args.include_joints and args.include_oracle:
    MODE = 'rgb_joints_oracle' # TODO: check if this is correct
elif not args.include_state and args.include_rgb and args.include_joints and not args.include_oracle:
    MODE = 'rgb_joints'
else:
    # Any remaining combination of the flags has no mode name.
    raise NotImplementedError(f"Unknown mode: {args.include_state=} {args.include_rgb=} {args.include_oracle=} {args.include_joints=}")

# NOTE(review): SAVE_DIR and run_name are only referenced from commented-out
# code in the part of the file visible here.
SAVE_DIR = f'checkpoints/ppo_memtasks/{MODE}/{args.reward_mode}/{args.env_id}'


print(f'{MODE=}')
print(f'{prompt_info=}')

# Put StateOnlyTensorToDictWrapper first so it is applied directly on top of
# the vectorized env. Per the original author's note it turns a tensor
# observation into a dict with keys state / prompt / oracle_info.
wrappers_list.insert(0, (StateOnlyTensorToDictWrapper, {}))


# Build the run name; when no experiment name was given, a placeholder is
# stored in args.exp_name and the environment id is prepended.
if args.exp_name is None:
    args.exp_name = 'asdfsfd'  # placeholder; previously os.path.basename(__file__)[: -len(".py")]
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{MODE}__{TIME}"
else:
    # run_name = args.exp_name
    run_name = f"{args.exp_name}__{args.seed}__{MODE}__{TIME}"


# Keyword arguments forwarded to gym.make; sensor_configs collects the camera
# resolution and shader settings.
env_kwargs = dict(sensor_configs=dict()) # render_mode="rgb_array",

# Only override the camera resolution when one was requested.
if args.camera_width is not None:
    env_kwargs["sensor_configs"]["width"] = args.camera_width
if args.camera_height is not None:
    env_kwargs["sensor_configs"]["height"] = args.camera_height



# NOTE: the number of environments is forced to 1 here, regardless of the
# value given on the command line.
args.num_envs = 1
env_kwargs["sensor_configs"]["shader_pack"] = args.shader
# NOTE: render_mode, control_mode and reconfiguration_freq are hard-coded
# below; args.render_mode, args.control_mode and the reconfiguration options
# are not consulted.
env = gym.make(
    args.env_id,
    obs_mode="rgb",
    num_envs=args.num_envs,
    render_mode="all",
    control_mode= 'pd_ee_delta_pose',
    reconfiguration_freq = 1,
    sim_backend = args.sim_backend,
    **env_kwargs
)

env = ManiSkillVectorEnv(env, args.num_envs, ignore_terminations=True, record_metrics=True)




# Apply the task-specific wrappers in list order (the first entry ends up
# innermost).
for wrapper_class, wrapper_kwargs in wrappers_list:
    env = wrapper_class(env, **wrapper_kwargs)
    # envs = wrapper_class(envs, **wrapper_kwargs)

# NOTE: video capture is forced on here, regardless of the command line.
args.capture_video = True
if args.capture_video:
    # eval_output_dir = f"{SAVE_DIR}/{run_name}/{TIME}/videossss"
    eval_output_dir = f"octo_eval/{args.env_id}/{TIME}/videos"

    # if args.evaluate:
    #     eval_output_dir = f"{os.path.dirname(args.checkpoint)}/test_videos"
    print(f"Saving eval videos to {eval_output_dir}")

    env = RecordEpisode(
        env,
        output_dir=eval_output_dir,
        save_trajectory=False,
        save_video=True,
        save_video_trigger=None,
        # save_on_reset=True
    )


# Outermost wrappers: keep only the camera images, stack the last 10
# observations, and execute one action per policy call.
env = CameraWrapper(env)
env = HistoryWrapper(env, horizon=10)
env = RHCWrapper(env, exec_horizon=1)

print("JAX backend:", jax.default_backend())
print("JAX devices:", jax.devices())


fps = 10
eval_metrics = defaultdict(list)
num_episodes = 0
# Follow the documented --num-episodes option; its default (100) equals the
# value that used to be hard-coded here.
num_eval_episodes = args.num_episodes

# Episode j is reset with seed j + 1.
seeds = list(range(1, num_eval_episodes + 1))

# The task and the action statistics do not change between episodes or steps,
# so they are built once instead of on every iteration.
# NOTE(review): Args.language_instruction defaults to the placeholder
# 'CHANGE ME' - pass a real instruction on the command line.
language_instruction = args.language_instruction
task = model.create_tasks(texts=[language_instruction])
action_mean = jnp.array(model.dataset_statistics["action"]["mean"])
action_std = jnp.array(model.dataset_statistics["action"]["std"])

for j in range(num_eval_episodes):

    print(f'Eval episode {j}')

    obs, info = env.reset(seed=seeds[j], options={})
    episode_finished = False

    for i in range(args.num_eval_steps):
        # Give the pad mask a leading axis and swap the first two axes of the
        # image stacks before calling the model.
        # NOTE(review): assumes the stacked images are (history, env, ...) and
        # the model wants (env, history, ...) - confirm.
        obs['pad_mask'] = np.expand_dims(obs['pad_mask'], 0)
        obs['image_primary'] = np.swapaxes(obs['image_primary'], 0, 1)
        obs['image_wrist'] = np.swapaxes(obs['image_wrist'], 0, 1)

        # NOTE(review): the same PRNG key is used for every call, so any
        # sampling noise is identical at each step - confirm this is intended.
        action = model.sample_actions(obs, task, rng=jax.random.PRNGKey(0))
        # Un-normalize the sampled actions with the dataset statistics.
        actions = action * action_std[None] + action_mean[None]

        actions = actions[0]

        obs, reward, done, trunc, info = env.step(np.array(actions))

        if "final_info" in info:
            # RHCWrapper returns info as a dict of per-step lists; index 0 is
            # the single executed step.
            mask = info["_final_info"][0]
            num_episodes += mask.sum()
            for k, v in info["final_info"][0]["episode"].items():
                eval_metrics[k].append(float(v.item()))
                wandb.log({k: float(v.item())}, step=j)
                wandb.log({f"mean_{k}": np.array(eval_metrics[k]).mean()}, step=j)

            episode_finished = True
            break

    # Only report success when this episode actually produced final metrics;
    # otherwise the lookup would fail on an empty list or show the result of
    # an earlier episode.
    success_history = eval_metrics.get('success_once')
    if episode_finished and success_history:
        succ = success_history[-1]
        print(f'episode {j} success_once: {succ}')
    else:
        print(f'episode {j} reported no success_once metric within {args.num_eval_steps} steps')


