import pickle
from dataclasses import dataclass
import os

from constants import DINO_SIZE, RESNET_SIZE
from push_t_env import PushTEnv
os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'

import sys
import torch
import random
import gym

import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
from robomimic.utils.file_utils import get_env_metadata_from_dataset

import yaml
import importlib

import numpy as np
from fast_scaler import FastScaler
from PIL import Image
from typing import List, Tuple, Dict, Any
#from vaes.models.vanilla_vae import VanillaVAE
#from r3m import load_r3m

# original_stdout = sys.stdout
# original_stderr = sys.stderr
# sys.stdout = open(os.devnull, 'w')
# sys.stderr = open(os.devnull, 'w')
#
# warnings.filterwarnings('ignore')

# Fallback no-op `profile` decorator so `@profile` annotations keep working when
# no `profile` name is already defined (presumably it is injected as a builtin
# by line_profiler's kernprof when profiling — confirm).
try:
    profile
except NameError:
    def profile(func):
        """Identity decorator used when no profiler provides `profile`."""
        return func

# try:
import mimicgen.utils.robomimic_utils as RobomimicUtils
# Some imported libraries (the original note blames d4rl) install handlers on the
# root logger; remove all of them so the project logger imported below starts clean.
import logging
# Iterate over a copy: removeHandler mutates logging.root.handlers.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
from logging_utils import logger
# finally:
#     sys.stdout.close()
#     sys.stderr.close()
#     sys.stdout = original_stdout
#     sys.stderr = original_stderr

import cv2
#from groundingdino.util.inference import load_model, load_image, predict, annotate
#import groundingdino.datasets.transforms as T
# Module-wide debug switch (its readers are not visible in this chunk).
DEBUG = False

# Precomputed 2*pi and its reciprocal.
TWO_PI = 2 * np.pi
INV_TWO_PI = 1 / TWO_PI

# Lazily-populated module state: the models/handles below stay None until some
# code path initializes them.
dino_model = None
r3m = None
vae = None
resnet = None
# Torch device used by module-level helpers (CUDA when available, else CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_transform = None
img_env = None
camera_name = None  # set by construct_env() for robosuite envs
pca = None  # loaded by construct_env() when the config has 'pca_pkl'

# Detection / tracking state; vision_ob and trackers are reset by reset_vision_ob().
grounding_dino_model = None
last_grounding_results = {}
vision_ob = torch.tensor([])
trackers = {}
last_boxes = {}
# Detection thresholds ("TRESHOLD" spelling is part of the public names).
# Presumably box-score and text-score cut-offs for GroundingDINO — confirm at the use site.
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25

# TAPIR point tracker and its jitted init/predict functions, set by init_tapir().
tapir = None
online_model_init = None
online_model_predict = None

# Presumably online point-tracking state; keypoint_viz is read by get_keypoint_viz().
query_features = []
causal_state = []
last_tracks = []
keypoint_viz = []

# Reusable tensors. The notes say the CPU ones are meant to live in pinned
# memory, though they start here as empty, unpinned tensors.
proprio_tensor_cpu = torch.tensor([]) # Pinned memory on CPU
frame_tensor_cpu = torch.tensor([]) # Pinned memory on CPU
ret_tensor = torch.tensor([])

flattened_obs_norm_squared = []
val_flattened_obs_norm_squared = []

@dataclass
class Dataset:
    """Expert-data bundle returned by load_and_scale_data.

    Fields annotated ``| None`` are None when load_and_scale_data runs with
    ``bc=True`` (or, for obs_scaler, when ``scale=False``).
    """
    name: str  # path of the backing pickle (the *_standardized.pkl copy when scaled)
    obs_scaler: FastScaler | None # None if not normalized
    act_scaler: FastScaler  # fit on all actions
    rot_indices: np.ndarray | torch.Tensor  # observation dims treated as rotational
    weights: np.ndarray | torch.Tensor | None  # per-dimension weights on the data device
    non_rot_indices: np.ndarray | torch.Tensor  # complement of rot_indices
    obs_matrix: List  # output of create_matrices (presumably one matrix per trajectory)
    act_matrix: List  # output of create_matrices (presumably one matrix per trajectory)
    traj_starts: np.ndarray | torch.Tensor | None
    flattened_obs_matrix: np.ndarray | torch.Tensor | None  # obs_matrix concatenated on the data device
    flattened_act_matrix: np.ndarray | torch.Tensor | None  # act_matrix concatenated on the data device
    processed_obs_matrix: np.ndarray | torch.Tensor | None  # weighted non-rotational columns (or the raw flattened matrix)

def hash_tensor(tensor: torch.Tensor):
    """Return the SHA-256 hex digest of a tensor's raw element bytes.

    The tensor is detached and copied to the CPU first, so tensors with
    gradients or on an accelerator can be hashed.
    """
    import hashlib

    raw_bytes = tensor.detach().cpu().numpy().tobytes()
    digest = hashlib.sha256(raw_bytes)
    return digest.hexdigest()

def load_and_scale_data(path, rot_indices, weights, ob_type='state', scale=True, bc=False, device='cuda', num_viewpoints=2):
    """Load expert trajectories, optionally standardize observations, and wrap them in a Dataset.

    When ``scale`` is True:
      * a FastScaler is fit on all observations;
      * depending on ``ob_type``, some dims are excluded from scaling (mean 0,
        scale 1): the rotational dims for 'retrieval', and the trailing
        image-feature block (512 * num_viewpoints for 'resnet',
        768 * num_viewpoints for 'dino');
      * every trajectory's observations are replaced by the transformed ones;
      * the result is saved as ``<path without extension>_standardized.pkl``,
        and that path becomes ``Dataset.name``.

    An action FastScaler is always fit; the actions themselves are not
    transformed here.

    Unless ``bc`` is True, the per-trajectory matrices are also concatenated
    onto ``device``. A processed copy of the observations is built as well:
    the non-rotational columns multiplied by ``weights``, or, when ``weights``
    is empty, the raw flattened matrix (with all-ones weights).

    Args:
        path: pickle file readable by load_expert_data.
        rot_indices: indices of rotational observation dims.
        weights: per-dimension weights (may be empty).
        ob_type: observation type; selects which dims are left unscaled.
        scale: whether to standardize observations.
        bc: if True, skip building the flattened/processed matrices (they are None).
        device: device for the flattened matrices and weights.
        num_viewpoints: number of camera views whose features are appended.

    Returns:
        Dataset
    """
    expert_data = load_expert_data(path)

    is_numpy = isinstance(expert_data[0]['observations'][0], np.ndarray)
    if is_numpy:
        observations = torch.from_numpy(np.concatenate([traj['observations'] for traj in expert_data]))
    else:
        observations = torch.concatenate([traj['observations'] for traj in expert_data])

    rot_indices = torch.tensor(rot_indices, dtype=torch.int32)
    # Separate non-rotational dimensions (set lookup instead of a per-item tensor scan).
    rot_index_set = set(rot_indices.tolist())
    non_rot_indices = torch.tensor([i for i in range(observations.shape[-1]) if i not in rot_index_set], dtype=torch.int32)

    if scale:
        obs_scaler = FastScaler()
        obs_scaler.fit(observations)

        if ob_type == 'retrieval':
            # Leave rotational dims unscaled.
            obs_scaler.mean_np[rot_indices] = 0.0
            obs_scaler.mean_torch[rot_indices] = 0.0
            obs_scaler.scale_np[rot_indices] = 1.0
            obs_scaler.scale_torch[rot_indices] = 1.0
        elif ob_type == 'resnet':
            RESNET_FEATURE_SIZE = 512
            _exclude_trailing_dims_from_scaling(obs_scaler, RESNET_FEATURE_SIZE * num_viewpoints)
        elif ob_type == 'dino':
            # Named so it does not shadow the module-level DINO_SIZE import.
            DINO_FEATURE_SIZE = 768
            _exclude_trailing_dims_from_scaling(obs_scaler, DINO_FEATURE_SIZE * num_viewpoints)

        for traj in expert_data:
            raw_obs = torch.from_numpy(traj['observations']) if is_numpy else traj['observations']
            traj['observations'] = obs_scaler.transform(raw_obs)

        # splitext instead of path[:-4], so extensions of any length are handled.
        new_path = os.path.splitext(path)[0] + '_standardized.pkl'
        save_expert_data(expert_data, new_path)
    else:
        new_path = path
        obs_scaler = None

    act_scaler = FastScaler()
    if isinstance(expert_data[0]['actions'][0], np.ndarray):
        act_scaler.fit(np.concatenate([traj['actions'] for traj in expert_data]))
    else:
        act_scaler.fit(torch.concatenate([traj['actions'] for traj in expert_data]))

    obs_matrix, act_matrix, traj_starts = create_matrices(expert_data, use_torch=True)

    if not bc:
        flattened_obs_matrix = torch.cat(list(obs_matrix), dim=0).to(device)
        flattened_act_matrix = torch.cat(list(act_matrix), dim=0).to(device)

        if len(weights) > 0:
            weights = torch.as_tensor(weights, dtype=torch.float32, device=device)
            processed_obs_matrix = flattened_obs_matrix[:, non_rot_indices] * torch.as_tensor(weights[non_rot_indices], dtype=flattened_obs_matrix.dtype)
        else:
            weights = torch.ones(obs_matrix[0][0].shape[0], dtype=torch.float32, device=device)
            processed_obs_matrix = flattened_obs_matrix

        traj_starts = torch.as_tensor(traj_starts)
    else:
        flattened_obs_matrix = None
        flattened_act_matrix = None
        weights = None
        processed_obs_matrix = None
        traj_starts = None

    return Dataset(new_path, obs_scaler, act_scaler, rot_indices, weights, non_rot_indices, obs_matrix, act_matrix, traj_starts, flattened_obs_matrix, flattened_act_matrix, processed_obs_matrix)


def _exclude_trailing_dims_from_scaling(obs_scaler, n_dims):
    """Set mean 0 / scale 1 on the last ``n_dims`` scaler dims and rebuild its torch copies."""
    obs_scaler.mean_np[-n_dims:] = 0
    obs_scaler.scale_np[-n_dims:] = 1
    obs_scaler.mean_torch = torch.as_tensor(obs_scaler.mean_np)
    obs_scaler.scale_torch = torch.as_tensor(obs_scaler.scale_np)

def online_model_init_func(frames, query_points):
    """Initialize query features for the query points.

    ``frames`` is preprocessed with model_utils and given two leading singleton
    axes (the indexing requires the preprocessed input to be 3-D). Relies on
    the globals ``tapir`` and ``model_utils`` set by init_tapir().
    """
    video = model_utils.preprocess_frames(frames)[None, None, :, :, :]
    grids = tapir.get_feature_grids(video, is_training=False)
    return tapir.get_query_features(
        video,
        is_training=False,
        query_points=query_points,
        feature_grids=grids,
    )

def online_model_predict_func(frames, query_features, causal_context):
    """Compute point tracks and occlusions given frames and query points.

    Returns ``(tracks, visibles, causal_context)``. The first two come from the
    last (finest) resolution level; ``causal_context`` is the updated context
    to pass to the next call. Relies on the globals ``tapir`` and
    ``model_utils`` set by init_tapir().
    """
    video = model_utils.preprocess_frames(frames)[None, None, :, :, :]
    grids = tapir.get_feature_grids(video, is_training=False)
    outputs = tapir.estimate_trajectories(
        video.shape[-3:-1],
        is_training=False,
        feature_grids=grids,
        query_features=query_features,
        query_points_in_video=None,
        query_chunk_size=64,
        causal_context=causal_context,
        get_causal_context=True,
    )
    next_context = outputs.pop('causal_context')

    # Only the final-resolution predictions are used; for higher-resolution
    # inputs, averaging across resolutions is typically better.
    final_tracks = outputs['tracks'][-1]
    final_occlusion = outputs['occlusion'][-1]
    final_expected_dist = outputs['expected_dist'][-1]
    visibles = model_utils.postprocess_occlusions(final_occlusion, final_expected_dist)
    return final_tracks, visibles, next_context

def init_tapir():
    """Load the causal BootsTAPIR checkpoint and JIT-compile the online tracker functions.

    Sets the module globals ``tapir``, ``online_model_init`` and
    ``online_model_predict``. It also binds ``tapir_model``, ``model_utils`` and
    ``jax`` globally, so online_model_init_func / online_model_predict_func can
    use them.
    """
    global tapir, online_model_init, online_model_predict, tapir_model, model_utils, jax
    from tapnet.models import tapir_model
    from tapnet.utils import model_utils
    import jax
    # Highest matmul precision: crucial for determinism.
    jax.config.update("jax_default_matmul_precision", "highest")

    ckpt_file = './model_checkpoints/tapnet/checkpoints/causal_bootstapir_checkpoint.npy'
    checkpoint = np.load(ckpt_file, allow_pickle=True).item()

    tapir_kwargs = {
        'use_causal_conv': True,
        'bilinear_interp_with_depthwise_conv': False,
        'pyramid_level': 1,
        'extra_convs': True,
        'softmax_temperature': 10.0,
    }

    tapir = tapir_model.ParameterizedTAPIR(checkpoint['params'], checkpoint['state'], tapir_kwargs=tapir_kwargs)
    online_model_init = jax.jit(online_model_init_func)
    online_model_predict = jax.jit(online_model_predict_func)

def get_object_pixel_coords(sim, obj_name, camera_name="agentview", offset=None, obj_size_ratio=False, height=256, width=256):
    """Project a geom's world position (plus an optional offset) into a camera image.

    Uses a pinhole model. The focal length is derived from the camera's
    vertical field of view (``sim.model.cam_fovy``) and the image ``height``.

    Args:
        sim: MuJoCo simulation exposing ``model`` / ``data`` (mujoco_py-style API).
        obj_name: geom name.
        camera_name: camera to project into.
        offset: 3-vector added to the geom position; defaults to zeros.
        obj_size_ratio: if True, ``offset`` is first multiplied element-wise by
            the geom's ``geom_size``.
        height, width: image size in pixels (defaults keep the previous 256x256).

    Returns:
        ``(height - v, width - u)``, where ``(u, v)`` are the pinhole pixel
        coordinates truncated to int. The code labels this tuple ``(y, x)``.
    """
    # None instead of a shared numpy array as the default argument.
    if offset is None:
        offset = np.zeros(3)

    obj_id = sim.model.geom_name2id(obj_name)
    if obj_size_ratio:
        offset = sim.model.geom_size[obj_id] * offset
    obj_pos = sim.data.geom_xpos[obj_id] + offset

    cam_id = sim.model.camera_name2id(camera_name)
    cam_pos = sim.data.cam_xpos[cam_id]
    cam_mat = sim.data.cam_xmat[cam_id].reshape(3, 3)

    # World frame -> camera frame.
    x, y, z = cam_mat.T @ (obj_pos - cam_pos)

    fovy = sim.model.cam_fovy[cam_id]
    f = (height / 2) / np.tan(np.deg2rad(fovy) / 2)

    u = int(width / 2 + (f * x / z))
    v = int(height / 2 - (f * y / z))

    # y, x
    return (height - v, width - u)

def get_joint_pixel_coords(sim, joint_name, body_name, camera_name="track", height=256, width=256):
    """Project a joint anchor's world position into a camera image.

    MuJoCo stores ``jnt_pos`` as the anchor position in the owning body's local
    frame. It is therefore rotated by the body's world orientation
    (``sim.data.body_xmat``) before being added to the body's world position.
    Assumes ``body_name`` is the body that owns ``joint_name``.

    Args:
        sim: MuJoCo simulation exposing ``model`` / ``data`` (mujoco_py-style API).
        joint_name: joint whose anchor is projected.
        body_name: body the joint belongs to.
        camera_name: camera to project into.
        height, width: image size in pixels (defaults keep the previous 256x256).

    Returns:
        ``(height - v, width - u)``, where ``(u, v)`` are the pinhole pixel
        coordinates truncated to int. The code labels this tuple ``(y, x)``.
    """
    body_id = sim.model.body_name2id(body_name)
    body_pos = sim.data.body_xpos[body_id]
    body_mat = sim.data.body_xmat[body_id].reshape(3, 3)

    joint_id = sim.model.joint_name2id(joint_name)
    # Local anchor offset -> world frame (previously added without rotation).
    joint_pos = body_pos + body_mat @ sim.model.jnt_pos[joint_id]

    cam_id = sim.model.camera_name2id(camera_name)
    cam_pos = sim.data.cam_xpos[cam_id]
    cam_mat = sim.data.cam_xmat[cam_id].reshape(3, 3)

    # World frame -> camera frame.
    x, y, z = cam_mat.T @ (joint_pos - cam_pos)

    fovy = sim.model.cam_fovy[cam_id]
    f = (height / 2) / np.tan(np.deg2rad(fovy) / 2)

    u = int(width / 2 + (f * x / z))
    v = int(height / 2 - (f * y / z))

    # y, x
    return (height - v, width - u)

def crop_and_resize(img, crop_corners) -> np.ndarray:
    """Crop ``img`` to a rectangle and resize the crop back to ``img``'s size.

    Args:
        img: image array indexed as (rows, cols, channels).
        crop_corners: two (x, y) corners, ``[(x0, y0), (x1, y1)]``. Rows
            ``y0:y1`` and columns ``x0:x1`` are kept. An empty sequence returns
            ``img`` unchanged.

    Returns:
        The resized crop, with the same height and width as ``img``.
    """
    if len(crop_corners) == 0:
        return img

    # shape[:2] is (rows, cols) == (height, width); the previous unpacking swapped them.
    height, width = img.shape[:2]
    (x0, y0), (x1, y1) = crop_corners[0], crop_corners[1]
    cropped_img = img[y0:y1, x0:x1, :]
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(cropped_img, (width, height))

class MaxAndSkipEnv(gym.Wrapper):
    """
    Frame-skipping wrapper: repeat each action ``skip`` times and return the
    element-wise max of the observations from the last two sub-steps.

    Adapted from Stable-Baselines3.

    Args:
        env (`gym.Env`):
            The environment to wrap.
        skip (`int`):
            Number of underlying env steps per wrapper step.
    """

    def __init__(self, env: gym.Env, skip: int = 4) -> None:
        super().__init__(env)
        self.skip = skip
        # Holds the two most recent raw observations for max-pooling over time.
        buffer_shape = (2, *env.observation_space.shape)
        self._obs_buffer = np.zeros(buffer_shape, dtype=env.observation_space.dtype)

    def step(self, action: int):
        """
        Repeat ``action`` up to ``skip`` times, summing the rewards.

        Stops early on termination/truncation.
        :param action: the action
        :return: max-pooled observation, summed reward, terminated, truncated, info
        """
        reward_sum = 0.0
        info = {}
        terminated = truncated = False
        penultimate, final = self.skip - 2, self.skip - 1
        for step_idx in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            if step_idx == penultimate:
                self._obs_buffer[0] = obs
            elif step_idx == final:
                self._obs_buffer[1] = obs
            reward_sum += reward
            if terminated | truncated:
                break
        # After an early stop the buffer may hold older frames; as in SB3, the
        # observation on the done=True frame is not relied upon.
        pooled = self._obs_buffer.max(axis=0)
        return pooled, reward_sum, terminated, truncated, info

class NoopResetEnv(gym.Wrapper):
    """
    Randomize initial states by stepping a random number (1..noop_max) of
    no-op actions after each reset. The no-op is assumed to be action 0.

    Adapted from Stable-Baselines3.

    Args:
        env (`gym.Env`):
            The environment to wrap; its action 0 must mean "NOOP".
        noop_max (`int`):
            Inclusive upper bound on the number of no-ops.
    """

    def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0
        action_meanings = env.unwrapped.get_action_meanings()
        assert action_meanings[self.noop_action] == "NOOP"

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        """Reset, then step the no-op action a random number of times.

        If the episode ends during the no-ops, the env is reset again and the
        fresh observation/info are kept.
        """
        self.env.reset(**kwargs)
        num_noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        for _ in range(num_noops):
            observation, _, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated | truncated:
                observation, info = self.env.reset(**kwargs)
        return observation, info

class AtariDictObservationWrapper(gym.ObservationWrapper):
    """Expose stacked Atari frames as ``{"image_observation": HxWxC uint8}``.

    Incoming observations are channel-first; they are transposed to channel-last.
    """

    def __init__(self, env):
        super().__init__(env)
        image_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 4), dtype=np.uint8)
        self.observation_space = gym.spaces.Dict({"image_observation": image_space})

    def observation(self, observation):
        # (C, H, W) -> (H, W, C)
        channels_last = np.transpose(observation, (1, 2, 0))
        return {"image_observation": channels_last}

class NumpyObsWrapper(gym.ObservationWrapper):
    """
    Materialize every observation as a numpy array via ``np.array``.

    RL code generally expects numpy arrays or tensors. Atari frame-stacking,
    for example, yields LazyFrames, which must be converted before use.
    """

    def observation(self, observation: Any) -> np.ndarray:
        as_array = np.array(observation)
        return as_array

def robotwin_reset(env, env_name, seed=42, gpu_id=0):
    """Assemble RoboTwin task/embodiment/camera settings and call ``env.setup_demo``.

    Reads:
      * the deploy-policy YAML;
      * ``./robotwin/task_config/demo_clean.yml``;
      * the ``_embodiment_config.yml`` / ``_camera_config.yml`` files under
        ``robotwin.envs.CONFIGS_PATH``;
      * each robot's ``config.yml``.
    Then calls ``env.setup_demo(is_test=True, **args)``.

    Args:
        env: RoboTwin task instance (constructed, not yet set up).
        env_name: task name.
        seed: task seed. For "pick_diverse_bottles", seeds in a known-bad list
            are skipped by repeatedly adding 100. None leaves args['seed'] /
            args['use_seed'] as loaded from the task config.
        gpu_id: stored as args['device_id'].
    """
    from robotwin.envs import CONFIGS_PATH
    with open("robotwin/policy/Your_Policy/deploy_policy.yml", "r", encoding="utf-8") as f:
        usr_args = yaml.safe_load(f)

    # Alternative task config: "demo_randomized".
    usr_args.update({"task_name": env_name, "task_config": "demo_clean"})
    task_config = usr_args["task_config"]

    with open(f"./robotwin/task_config/{task_config}.yml", "r", encoding="utf-8") as f:
        args = yaml.load(f.read(), Loader=yaml.FullLoader)

    args['task_name'] = env_name
    args["task_config"] = task_config

    embodiment_type = args.get("embodiment")

    # Use os.path.join for both config files so the paths resolve whether or
    # not CONFIGS_PATH ends with a separator.
    with open(os.path.join(CONFIGS_PATH, "_embodiment_config.yml"), "r", encoding="utf-8") as f:
        _embodiment_types = yaml.load(f.read(), Loader=yaml.FullLoader)

    with open(os.path.join(CONFIGS_PATH, "_camera_config.yml"), "r", encoding="utf-8") as f:
        _camera_config = yaml.load(f.read(), Loader=yaml.FullLoader)

    head_camera_type = args["camera"]["head_camera_type"]
    args["head_camera_h"] = _camera_config[head_camera_type]["h"]
    args["head_camera_w"] = _camera_config[head_camera_type]["w"]

    # NOTE(review): both arms use embodiment_type[0]; multi-entry embodiment
    # lists (separate left/right robots) are not handled here.
    robot_file = _embodiment_types[embodiment_type[0]]["file_path"]
    args["left_robot_file"] = robot_file
    args["right_robot_file"] = robot_file
    args["dual_arm_embodied"] = True

    with open(os.path.join(args["left_robot_file"], "config.yml"), "r", encoding="utf-8") as f:
        args["left_embodiment_config"] = yaml.load(f.read(), Loader=yaml.FullLoader)

    with open(os.path.join(args["right_robot_file"], "config.yml"), "r", encoding="utf-8") as f:
        args["right_embodiment_config"] = yaml.load(f.read(), Loader=yaml.FullLoader)

    # NOTE(review): usr_args is not used after this point, so these arm dims are discarded.
    usr_args["left_arm_dim"] = len(args["left_embodiment_config"]["arm_joints_name"][0])
    usr_args["right_arm_dim"] = len(args["right_embodiment_config"]["arm_joints_name"][1])

    if seed is not None:
        if env_name == "pick_diverse_bottles":
            bad_seeds = {
                0, 100, 3, 103, 203, 303, 403, 4, 5, 7, 12, 112, 212, 312, 412, 15, 115, 215,
                16, 116, 17, 117, 18, 20, 120, 220, 320, 25, 125, 27, 28, 128, 228, 31, 131,
                32, 35, 135, 235, 335, 435, 535, 36, 37, 137, 237, 337, 437, 39, 139, 239,
                40, 41, 141, 43, 44, 144, 244, 344, 444, 544, 47, 147, 49, 149, 249, 53, 153,
                54, 57, 59, 60, 62, 63, 163, 263, 65, 165, 265, 66, 166, 68, 168, 268, 70,
                170, 72, 73, 74, 174, 76, 78, 79, 80, 83, 183, 85, 185, 285, 86, 87, 89, 189,
                90, 190, 91, 191, 291, 391, 491, 591, 92, 192, 93, 94, 194, 294, 96, 98, 99,
            }
        else:
            bad_seeds = set()

        # Step past known-bad seeds in increments of 100.
        while seed in bad_seeds:
            seed += 100
        args['seed'] = seed
        args['use_seed'] = True

    args['device_id'] = gpu_id

    env.setup_demo(is_test=True, **args)

#@profile
def construct_env(config, seed=None, gpu_id=0):
    """Create the environment described by ``config``.

    Branches, checked in order:
      * ``config['robosuite']``: robomimic env built from the metadata in
        ``config['demo_hdf5']``. Also sets the module-global ``camera_name``
        and returns immediately.
      * name starting with "ALE": Atari env with the usual wrapper stack
        (no-op reset, frame-skip 4, 84x84 grayscale, 4-frame stack, dict
        observation). Returns immediately.
      * ``config['metaworld']``: Meta-World MT50 V2 env with
        ``_partially_observable=False``, ``_freeze_rand_vec=False`` and
        ``_set_task_called=True``.
      * ``config['robotwin']``: ``robotwin.envs.<name>.<name>()`` (not set up here).
      * name == 'push_t': PushTEnv.
      * otherwise: ``gym.make`` after importing d4rl. hopper-expert-v2 geoms are
        recolored; for maze2d-umaze-v1 the outer wrapper is stripped and the
        reward is made sparse.

    If ``config`` has 'pca_pkl', that pickle is loaded into the module-global
    ``pca``. The robosuite and ALE branches return before this step.

    Args:
        config: environment config dict.
        seed: optional seed stored in the robosuite env metadata.
        gpu_id: render GPU id for robosuite.
    """
    global pca, camera_name
    is_robosuite = config.get('robosuite', False)

    if is_robosuite:
        # Register a minimal observation spec with robomimic before building the env.
        dummy_spec = dict(
            obs=dict(
                    low_dim=["robot0_eef_pos"],
                    rgb=[],
                ),
        )

        ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)

        env_meta = get_env_metadata_from_dataset(dataset_path=config['demo_hdf5'])
        env_meta['env_kwargs']['hard_reset'] = False
        env_meta['env_kwargs']['render_gpu_device_id'] = gpu_id
        env_meta['env_kwargs']['reward_shaping'] = config.get("reward_shaping", False)
        if seed is not None:
            env_meta['seed'] = seed
        print(env_meta)
        env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render_offscreen=True)

        camera_name = RobomimicUtils.get_default_env_cameras(env_meta=env_meta)[0]
        return env

    env_name = config['name']

    is_metaworld = config.get('metaworld', False)
    is_robotwin = config.get('robotwin', False)

    if env_name.startswith("ALE"):
        import ale_py  # noqa: F401 -- imported only for its side effects (presumably env registration)
        env = gym.make(env_name, render_mode="rgb_array", frameskip=1)
        env.reward_range = (-np.inf, np.inf)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, num_stack=4)
        env = NumpyObsWrapper(env)
        env = AtariDictObservationWrapper(env)
        return env
    elif is_metaworld:
        # Imported lazily (like ale_py / d4rl) so metaworld is only required for
        # Meta-World configs; `_env_dict` is not imported anywhere else in this module.
        import metaworld.envs.mujoco.env_dict as _env_dict
        env = _env_dict.MT50_V2[env_name]()
        env._partially_observable = False
        env._freeze_rand_vec = False
        env._set_task_called = True
    elif is_robotwin:
        envs_module = importlib.import_module(f"robotwin.envs.{env_name}")
        env_class = getattr(envs_module, env_name)
        env = env_class()
        # Task setup (robotwin_reset) is not performed here.
    elif env_name == 'push_t':
        env = PushTEnv()
    else:
        import d4rl  # noqa: F401 -- imported for its side effects before gym.make
        env = gym.make(env_name)
        if env_name == "hopper-expert-v2":
            distinct_colors = [
                [1, 0, 0, 1],  # Red
                [0, 1, 0, 1],  # Green
                [0, 0, 1, 1],  # Blue
                [1, 1, 0, 1],  # Yellow
                [1, 0, 1, 1],  # Magenta
            ]

            sim_model = env.sim.model
            for i, geom in enumerate(sim_model.geom_names):
                geom_id = sim_model.geom_name2id(geom)
                if geom == 'floor':
                    # Drop the floor material and paint it plain white.
                    sim_model.geom_matid[geom_id] = -1
                    sim_model.geom_rgba[geom_id] = [1, 1, 1, 1]
                else:
                    sim_model.geom_matid[geom_id] = 1
                    # Cycle through the palette so models with more geoms cannot IndexError.
                    sim_model.geom_rgba[geom_id] = distinct_colors[i % len(distinct_colors)]
        if env_name == "maze2d-umaze-v1":
            env = env.env
            env.reward_type = 'sparse'

    if 'pca_pkl' in config:
        # pickle.load executes arbitrary code: only point 'pca_pkl' at trusted files.
        with open(config['pca_pkl'], 'rb') as f:
            pca = pickle.load(f)

    return env

def get_proprio(config, obs) -> np.ndarray:
    """Extract the proprioceptive vector from a raw observation.

    * robosuite: end-effector position, quaternion and gripper joint positions,
      stacked into one float array.
    * robotwin: left/right end-pose and gripper values from ``obs['endpose']``,
      flattened and concatenated as float32.
    * anything else: ``obs`` is returned unchanged.
    """
    if config.get('robosuite', False):
        low_dim_keys = ("robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos")
        return np.hstack([np.array([]), *(obs[key] for key in low_dim_keys)])

    if config.get('robotwin', False):
        endpose = obs['endpose']
        endpose_keys = ('left_endpose', 'left_gripper', 'right_endpose', 'right_gripper')
        pieces = [np.array(endpose[key]).reshape(-1) for key in endpose_keys]
        return np.concatenate(pieces).astype(np.float32)

    return obs

def reset_vision_ob():
    """Clear the cached vision observation (to an empty tensor on ``device``) and the trackers."""
    global vision_ob, trackers
    empty_on_device = torch.tensor([], device=device)
    vision_ob = empty_on_device
    trackers = dict()

def hide_robot(model):
    """Make every geom whose name starts with 'robot', 'gripper' or 'mount' fully transparent.

    Mutates ``model.geom_rgba`` in place by setting the alpha channel to 0.
    """
    hidden_prefixes = ('robot', 'gripper', 'mount')
    for geom_idx, geom_name in enumerate(model.geom_names):
        if not geom_name.startswith(hidden_prefixes):
            continue
        # Copy the row, edit it, and write it back as a whole.
        new_rgba = model.geom_rgba[geom_idx].copy()
        new_rgba[3] = 0
        model.geom_rgba[geom_idx] = new_rgba

#@profile
def get_processed_obs(observation, frame, env, model, config, obs_type, numpy_action=True, is_first_ob=False):
    device = config['device']
    env_name = config.get('name', 10)
    cam_names = config.get("cams", [])
    is_robosuite = config.get('robosuite', False)

    if not isinstance(observation, list):
        observation = [observation]

    proprio_state = []
    for o in observation:
        if config.get('add_proprio', False):
            proprio_state.append(get_proprio(config, o))
        else:
            proprio_state.append(np.array([]))

    proprio_state = np.array(proprio_state)

    match obs_type:
        case 'state':
            return torch.tensor(crop_obs_for_env(observation, env_name, env_instance=env), device=device, dtype=torch.float32)
        case 'proprio':
            return torch.tensor(crop_obs_for_env(observation, env_name, env_instance=env, proprio=True), device=device, dtype=torch.float32)
        case 'dino':
            if not isinstance(frame, list):
                frame = [frame]
                batch_size = 1
            else:
                batch_size = len(frame)

            frame = np.array(frame)

            assert frame.shape[2] % 224 == 0
            num_viewpoints = frame.shape[2] // 224
            split_frames = np.stack(np.split(frame, num_viewpoints, axis=2), axis=1).reshape((batch_size * num_viewpoints, 224, 224, 3))
            image_features = model.dino.frames_to_dino(split_frames).reshape((batch_size, DINO_SIZE * num_viewpoints))

            return torch.hstack((torch.tensor(proprio_state, device=device, dtype=torch.float32), image_features))
        case 'r3m':
            if not isinstance(frame, list):
                frame = [frame]
                batch_size = 1
            else:
                batch_size = len(frame)

            frame = np.array(frame)

            assert frame.shape[2] % 224 == 0
            num_viewpoints = frame.shape[2] // 224
            split_frames = np.stack(np.split(frame, num_viewpoints, axis=2), axis=1).reshape((batch_size * num_viewpoints, 224, 224, 3))
            image_features = model.r3m.frames_to_r3m(split_frames).reshape((batch_size, RESNET_SIZE * num_viewpoints))

            return torch.hstack((torch.tensor(proprio_state, device=device, dtype=torch.float32), image_features))
        case 'vae':
            if not isinstance(frame, list):
                frame = [frame]
                batch_size = 1
            else:
                batch_size = len(frame)

            frame = np.array(frame)

            assert frame.shape[2] % 224 == 0
            num_viewpoints = frame.shape[2] // 224
            split_frames = np.stack(np.split(frame, num_viewpoints, axis=2), axis=1).reshape((batch_size, num_viewpoints, 224, 224, 3))

            latent_dims = [0]
            for i in range(num_viewpoints):
                latent_dims.append(model.vaes[i].latent_dim)

            latent_dims = torch.cumsum(torch.as_tensor(latent_dims), dim=0)

            image_features = torch.empty((batch_size, latent_dims[-1]), device=device)
            for i in range(num_viewpoints):
                image_features[:, latent_dims[i]:latent_dims[i+1]] = model.vaes[i].frames_to_vae(split_frames[:, i])

            return torch.hstack((torch.tensor(proprio_state, device=device, dtype=torch.float32), image_features))
        case 'resnet':
            return torch.hstack((torch.tensor(proprio_state, device=device, dtype=torch.float32), model.resnet.frames_to_resnet(np.array(frame)).squeeze(0)))
        case 'keypoint' | 'semantic_keypoint':
            return frame_to_keypoints(env_name, frame, env, is_robosuite=is_robosuite, is_first_ob=is_first_ob, proprio_state=proprio_state, cam_names=cam_names, semantic=(obs_type == "semantic_keypoint"))
        case 'rgb':
            return torch.as_tensor(np.hstack([proprio_state, cv2.resize(frame, (84, 84)).flatten()]), device=device, dtype=torch.float32)
        case 'rgb_dino':
            if not isinstance(frame, list):
                frame = [frame]
                batch_size = 1
            else:
                batch_size = len(frame)

            frame = np.array(frame)

            assert frame.shape[2] % 224 == 0
            num_viewpoints = frame.shape[2] // 224
            split_frames = np.array(np.split(frame, num_viewpoints, axis=2))  # (num_viewpoints, batch_size, 224, 224, 3)
            split_frames = split_frames.reshape(num_viewpoints, batch_size, -1).transpose(1, 0, 2).reshape(batch_size, -1)
            return torch.as_tensor(np.hstack([proprio_state, split_frames]), device=device, dtype=torch.float32)

#@profile
def get_action_from_obs(config, model, env, observation, frame, obs_history=None, numpy_action=True, is_first_ob=False):
    """Process one observation and query ``model`` for an action (as a numpy array).

    In 'mixed' mode, a dict with 'retrieval' and 'delta_state' entries is built.
    An observation type shared by both entries is processed once and cloned.
    Models exposing ``get_action`` also receive the current frame, resized to
    224x224 and flattened.

    ``obs_history`` and ``numpy_action`` are accepted for interface
    compatibility; obs_history is unused here.
    """
    if is_first_ob:
        reset_vision_ob()

    if config.get('mixed'):
        obs = {}
        processed_obs_types = {}
        for dataset in ('retrieval', 'delta_state'):
            obs_type = config[dataset]['type']
            if obs_type in processed_obs_types:
                # Already computed for the other dataset: reuse a detached copy.
                obs[dataset] = obs[processed_obs_types[obs_type]].detach().clone()
            else:
                obs[dataset] = get_processed_obs(observation, frame, env, model, config, obs_type, numpy_action=numpy_action, is_first_ob=is_first_ob)
                processed_obs_types[obs_type] = dataset
    else:
        obs = get_processed_obs(observation, frame, env, model, config, config['type'], numpy_action=numpy_action, is_first_ob=is_first_ob)

    assert obs is not None

    if hasattr(model, "get_action"):
        # cv2.resize's third positional parameter is `dst`, so the interpolation
        # flag must be passed by keyword or it is silently ignored.
        curr_rgb_obs = cv2.resize(frame, (224, 224), interpolation=cv2.INTER_AREA).flatten()
        action = model.get_action(obs, curr_rgb_obs=curr_rgb_obs).squeeze()
    else:
        action = model(obs.unsqueeze(0)).squeeze(0)

    return action.cpu().detach().numpy()

#@profile
def get_action_from_obs_batched(config, model, envs, observations, frames, obs_history=None, numpy_action=True, is_first_ob=False):
    """Batched counterpart of get_action_from_obs.

    In 'mixed' mode, the 'retrieval' and 'delta_state' features are computed
    (once per distinct obs type) and concatenated with torch.hstack.

    If ``obs_history`` is given, the new processed obs is appended along dim 1
    and the whole history is fed to ``model``. A history whose dim 2 is 0 is
    treated as an empty placeholder and rebuilt with the obs feature width.

    Returns:
        ``(actions as numpy, obs_history)``
    """
    if is_first_ob:
        reset_vision_ob()

    if config.get('mixed'):
        obs = {}
        processed_obs_types = {}
        for dataset in ('retrieval', 'delta_state'):
            obs_type = config[dataset]['type']
            if obs_type in processed_obs_types:
                # Already computed for the other dataset: reuse a detached copy.
                obs[dataset] = obs[processed_obs_types[obs_type]].detach().clone()
            else:
                obs[dataset] = get_processed_obs(observations, frames, envs, model, config, obs_type, numpy_action=numpy_action, is_first_ob=is_first_ob)
                processed_obs_types[obs_type] = dataset

        obs = torch.hstack((obs['retrieval'], obs['delta_state']))
    else:
        obs = get_processed_obs(observations, frames, envs, model, config, config['type'], numpy_action=numpy_action, is_first_ob=is_first_ob)

    if hasattr(model, "get_action"):
        # Previously this referenced an undefined `frame`. Resize with INTER_AREA
        # (by keyword: cv2.resize's third positional parameter is `dst`).
        # NOTE(review): a list of frames yields one flattened row per frame;
        # confirm this is the layout model.get_action expects for batches.
        if isinstance(frames, list):
            curr_rgb_obs = np.stack([cv2.resize(f, (224, 224), interpolation=cv2.INTER_AREA).flatten() for f in frames])
        else:
            curr_rgb_obs = cv2.resize(frames, (224, 224), interpolation=cv2.INTER_AREA).flatten()
        actions = model.get_action(obs, curr_rgb_obs=curr_rgb_obs).squeeze()
    elif obs_history is not None:
        if obs_history.shape[2] == 0:
            obs_history = torch.empty((obs_history.shape[0], 0, obs.shape[-1]), device=obs_history.device)
        obs_history = torch.cat((obs_history, obs.unsqueeze(1)), dim=1)
        actions = model(obs_history)
    else:
        actions = model(obs)

    return actions.cpu().detach().numpy(), obs_history

def get_keypoint_viz(cam_names):
    """Concatenate the stored keypoint tracks and visibilities along axis 1.

    Reads the module-level ``keypoint_viz`` list. Each entry is a dict whose
    'tracks' and 'visibles' values are indexed with ``[0]`` before
    concatenation. ``cam_names`` is accepted but not used.

    Returns:
        ``(tracks, visibles)`` numpy arrays.
    """
    entries = keypoint_viz
    all_tracks = np.concatenate([entry['tracks'][0] for entry in entries], axis=1)
    all_visibles = np.concatenate([entry['visibles'][0] for entry in entries], axis=1)
    return all_tracks, all_visibles

def rgb_to_features(rgb_data, proprio_data, featurizer):
    """Replace each trajectory's RGB observations with ``[proprio | image features]``.

    Args:
        rgb_data: list of trajectory dicts whose ``'observations'`` hold uint8 pixels
            that are viewed as 224x224x3 (HWC) frames.
        proprio_data: list of trajectory dicts aligned step-for-step with ``rgb_data``
            whose ``'observations'`` hold the proprioceptive state.
        featurizer: torch module applied to ImageNet-normalised (N, 3, 224, 224)
            batches; its output is flattened per image, so pooled (N, C, 1, 1) and
            plain (N, C) outputs of any width are both accepted.

    Returns:
        ``rgb_data`` itself, modified in place: every ``traj['observations']`` becomes a
        numpy array of shape (T, proprio_dim + feature_dim).
    """
    obs_matrix, _, _ = create_matrices(rgb_data, use_torch=True)
    flattened_obs_matrix = torch.cat([torch.as_tensor(obs, dtype=torch.uint8) for obs in obs_matrix], dim=0)
    all_images = flattened_obs_matrix.view(-1, 224, 224, 3).permute(0, 3, 1, 2).to(device)

    n_samples = all_images.shape[0]
    batch_size = 1024
    # ImageNet mean/std, built once rather than once per batch.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    # Sized from the first batch so the feature width follows the featurizer.
    img_features = None
    with torch.no_grad():
        for start_idx in range(0, n_samples, batch_size):
            end_idx = min(start_idx + batch_size, n_samples)
            images = all_images[start_idx:end_idx] / 255.0
            features = featurizer((images - mean) / std)
            # Flatten e.g. (B, C, 1, 1) pooled outputs to (B, C).
            batch_features = features.reshape(features.shape[0], -1).detach().cpu().numpy()
            if img_features is None:
                img_features = np.empty((n_samples, batch_features.shape[1]))
            img_features[start_idx:end_idx] = batch_features
    if img_features is None:
        # No images at all: keep an empty 512-wide feature block so hstack still works.
        img_features = np.empty((0, 512))

    prop_obs_matrix, _, _ = create_matrices(proprio_data, use_torch=True)
    flattened_prop_obs_matrix = torch.cat([torch.as_tensor(obs) for obs in prop_obs_matrix], dim=0)

    start = 0
    for traj in rgb_data:
        n_steps = len(traj["observations"])
        traj["observations"] = np.hstack([
            flattened_prop_obs_matrix[start:start + n_steps].detach().cpu().numpy(),
            img_features[start:start + n_steps],
        ])
        start += n_steps

    return rgb_data

#@profile
def stack_with_previous(obs_list, stack_size):
    """Return exactly the last ``stack_size`` observations, left-padding with the first.

    Args:
        obs_list: tensor whose first dimension indexes time steps; any number of
            trailing feature dimensions is supported.
        stack_size: number of time steps to return.

    Returns:
        A new tensor of shape ``(stack_size, *obs_list.shape[1:])``. When fewer than
        ``stack_size`` steps exist, the earliest step is repeated at the front.

    Raises:
        IndexError: if padding is needed but ``obs_list`` is empty.
    """
    n_steps = len(obs_list)
    if n_steps < stack_size:
        if n_steps == 0:
            raise IndexError("stack_with_previous: obs_list is empty, nothing to pad with")
        # Repeat the first step along dim 0 only, whatever the trailing rank.
        padding = obs_list[:1].repeat(stack_size - n_steps, *([1] * (obs_list.dim() - 1)))
        return torch.cat([padding, obs_list], dim=0)
    # Explicit start index: obs_list[-0:] would return every step for stack_size == 0.
    return obs_list[n_steps - stack_size:].clone()

#@profile
def frame_to_obj_centric_dino(env_name, rgb_array, proprio_state=np.array([]), numpy_action=True):
    """Build an object-centric observation: proprioception followed by per-object vision features.

    The vision part is cached in the module-level ``vision_ob`` and is only recomputed
    on a call where ``vision_ob`` is empty; otherwise the cached value is reused as is.
    When it is recomputed:

    * if no trackers exist yet, objects are detected with ``get_semantic_frame_and_box``
      (camera key ``""``) and an OpenCV CSRT tracker is initialised per object;
    * otherwise every existing tracker is updated on ``rgb_array``, falling back to the
      object's last box when tracking fails.

    For each object, the first two box entries and ``frame_to_dino`` features are
    appended to ``vision_ob``. The features come from the detector's crop on the first
    path and from the full ``rgb_array`` on the tracker path.

    Args:
        env_name: environment name passed to ``get_semantic_frame_and_box``.
        rgb_array: current RGB frame.
        proprio_state: numpy array of proprioceptive values. Its length sizes a pinned
            float64 CPU staging buffer.
        numpy_action: not used by this function.

    Returns:
        torch.Tensor on ``device``: ``hstack([proprio_state, vision_ob])``.

    Side effects:
        Rebinds or mutates the module globals ``proprio_tensor_cpu``, ``trackers``,
        ``last_boxes`` and ``vision_ob``.
    """
    global proprio_tensor_cpu, trackers, last_boxes, vision_ob
    # Same expression as the module-level `device`; shadows it locally.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Lazily allocate a pinned host buffer so the host->device copy below can be non_blocking.
    # NOTE(review): `proprio_tensor_cpu` has no default in the module header; unless it has
    # been assigned earlier (frame_to_keypoints assigns it), this raises NameError — confirm.
    if len(proprio_tensor_cpu) == 0:
        proprio_tensor_cpu = torch.empty(len(proprio_state), dtype=torch.float64, device='cpu', pin_memory=True)
    proprio_tensor_cpu.copy_(torch.from_numpy(proprio_state))

    obs = proprio_tensor_cpu.to(device, non_blocking=True)

    if len(vision_ob) == 0:
        if len(trackers) == 0:
            frame_and_box = get_semantic_frame_and_box("", env_name, rgb_array)
            trackers = {}
            for (_, box, obj) in frame_and_box:
                # Center to top left
                # Detector boxes are (cx, cy, w, h); shift to the top-left (x, y, w, h) form.
                # NOTE(review): this edits `box` in place. get_semantic_frame_and_box keeps
                # the same list/tensor in `last_grounding_results`, so the cached box shifts too.
                box[0] -= box[2] / 2
                box[1] -= box[3] / 2

                trackers[obj] = cv2.TrackerCSRT_create()
                # NOTE(review): `.to(torch.uint8)` wraps values outside [0, 255]. Detector boxes
                # are scaled by 256 upstream, so boxes touching the image edge can be corrupted.
                trackers[obj].init(rgb_array, box.to(torch.uint8).tolist())
                last_boxes[obj] = box

        else:
            frame_and_box = []
            for obj in trackers.keys():
                success, box = trackers[obj].update(rgb_array)

                # Tracking lost: reuse the last known box for this object.
                if not success:
                    box = last_boxes[obj]

                frame_and_box.append((rgb_array, torch.tensor(box, device=device), obj))

        for (frame, box, obj) in frame_and_box:
            box_xy, last_box_xy = box[:2], last_boxes[obj][:2]
            # Angle between the current and previous box position vectors, plus the distance moved.
            # NOTE(review): only consumed by the commented-out hstack below, so currently
            # unused. A zero-length position vector makes cos_theta NaN.
            diff = box_xy - last_box_xy
            dots = torch.sum(box_xy * last_box_xy)

            norm_a = torch.sqrt(torch.sum(box_xy**2))
            norm_b = torch.sqrt(torch.sum(last_box_xy**2))

            cos_theta = dots / (norm_a * norm_b)
            cos_theta = torch.clamp(cos_theta, -1.0, 1.0)

            angle_and_mag = torch.stack([
                torch.arccos(cos_theta),
                torch.sqrt(torch.sum(diff**2))
            ])

            last_boxes[obj] = box
            dino_features = frame_to_dino(frame, numpy_action=False)
            # Per object: box x/y followed by DINO features.
            #vision_ob = torch.hstack([vision_ob, box_xy, angle_and_mag, dino_features])
            vision_ob = torch.hstack([vision_ob, box_xy, dino_features])

    obs = torch.hstack([obs, vision_ob])

    return obs

#@profile
def env_to_rgb_array(env, camera, crop_corners, width, height):
    """Render ``camera`` from ``env``, crop a normalised region and resize it.

    Args:
        env: environment whose ``render(mode='rgb_array', height=..., width=...,
            camera_name=...)`` returns an image array.
        camera: camera name forwarded to ``env.render``.
        crop_corners: 2x2 array-like ``[[x0, y0], [x1, y1]]`` giving the crop in units
            of the rendered frame size (the corners are multiplied by the render size).
            It is not modified.
        width: output width in pixels.
        height: output height in pixels.

    Returns:
        The cropped frame resized to ``height`` rows x ``width`` columns.
    """
    # Work on a float copy. Scaling the caller's array in place would make repeated
    # calls with the same array compound the scaling (and fails for integer arrays).
    corners = np.array(crop_corners, dtype=np.float64)
    crop_width = corners[1][0] - corners[0][0]
    crop_height = corners[1][1] - corners[0][1]
    # Square render large enough that the crop spans at least width x height pixels.
    render_size = max(width / crop_width, height / crop_height)

    frame = env.render(mode='rgb_array', height=round(render_size), width=round(render_size), camera_name=camera)
    assert frame is not None

    pixel_corners = np.round(corners * render_size).astype(np.uint16)
    cropped_frame = frame[pixel_corners[0][1]:pixel_corners[1][1], pixel_corners[0][0]:pixel_corners[1][0], :]
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(cropped_frame, (width, height))

def eval_over(steps, config, env_instance):
    """Decide whether an evaluation rollout has reached its stopping condition.

    Args:
        steps: number of environment steps taken so far.
        config: environment config; reads ``'name'`` and the optional ``'metaworld'`` flag.
        env_instance: live environment. It is only queried (``_get_obs()`` and
            ``_target``) for ``maze2d-umaze-v1``.

    Returns:
        Truthy once the step budget is exhausted: 1000 steps whenever ``metaworld`` is
        set, otherwise a per-task budget. ``maze2d-umaze-v1`` also ends when the agent is
        within 0.5 of its target. Unlisted environments never end here.
    """
    env_name = config['name']
    is_metaworld = config.get('metaworld', False)

    metaworld_done = bool(is_metaworld) and steps >= 1000
    if metaworld_done:
        return metaworld_done

    if env_name == "maze2d-umaze-v1":
        distance_to_goal = np.linalg.norm(env_instance._get_obs()[0:2] - env_instance._target)
        return distance_to_goal <= 0.5 or steps >= 500

    # Per-task step budgets (a blanket 200-step robosuite cap is currently disabled).
    step_budget = {
        "push_t": 200,
        "hopper-expert-v2": 1000,
        "Stack_D0": 200,
        "Square_D0": 200,
        "StackThree_D0": 350,
        "ThreePieceAssembly_D0": 400,
        "Threading_D0": 300,
    }.get(env_name)
    return step_budget is not None and steps >= step_budget

def crop_obs_for_env(obs, env, env_instance=None, proprio=False):
    """Select the observation components used as policy input for ``env``.

    Args:
        obs: raw observation(s). It is a flat vector for the sliced tasks below, a 2-D
            array-like for ``ant-expert-v2`` (first 27 columns kept), and an iterable of
            dicts for the dict-observation tasks (``Stack_D0``, ``Square_D0``, ...).
        env: environment name selecting the cropping rule.
        env_instance: unused.
        proprio: for dict-observation tasks, drop ``'object'`` and keep only robot state.

    Returns:
        np.ndarray. Dict-observation tasks give one float32 row per observation;
        unknown environments give ``np.array(obs)``.
    """
    # (head, middle, tail) slices kept from the flat observation vector.
    slice_plans = {
        "coffee-pull-v2": (slice(None, 11), slice(18, 29), slice(-3, None)),
        "coffee-push-v2": (slice(None, 11), slice(18, 29), slice(-3, None)),
        "button-press-topdown-v2": (slice(None, 9), slice(18, 27), slice(-2, None)),
        "drawer-close-v2": (slice(None, 7), slice(18, 25), slice(-3, None)),
    }
    dict_obs_envs = ("Square_D1", "Stack_D0", "PickAndPlace_D0", "Threading_D0",
                     "ThreePieceAssembly_D0", "StackThree_D0", "Square_D0")

    if env == "ant-expert-v2":
        return np.array(obs)[:, :27]

    if env in slice_plans:
        return np.concatenate(tuple(obs[part] for part in slice_plans[env]))

    if env in dict_obs_envs:
        wanted_keys = [
            name for name in ("robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object")
            if not (proprio and name == "object")
        ]
        rows = [np.hstack([np.array([])] + [entry[name] for name in wanted_keys]) for entry in obs]
        return np.array(rows).astype(np.float32)

    return np.array(obs)

#@partial(jax.jit, static_argnums=())
def fast_2d_angles_and_magnitudes_jax(tracks_r, last_tracks_r):
    """Per point: angle between position vectors and displacement magnitude.

    Args:
        tracks_r: (N, 2) array of current 2-D point positions.
        last_tracks_r: (N, 2) array of the same points at the previous step.

    Returns:
        A (2, N) array. Row 0 is the angle in radians, in [0, pi], between each current
        and previous position vector (taken as vectors from the origin). Row 1 is the
        Euclidean distance between them. If either vector has zero length, the angle is
        defined as 0 instead of the NaN produced by 0/0.
    """
    import jax.numpy as jnp  # local import: jax is not among this module's top-level imports

    diff = tracks_r - last_tracks_r
    dots = jnp.sum(tracks_r * last_tracks_r, axis=1)
    norms = jnp.sqrt(jnp.sum(tracks_r**2, axis=1) * jnp.sum(last_tracks_r**2, axis=1))

    # A zero-length vector (e.g. an all-zero placeholder query point) has no direction.
    # Guard the division so NaN never reaches the observation vector.
    nonzero = norms > 0
    safe_norms = jnp.where(nonzero, norms, 1.0)
    cos_theta = jnp.where(nonzero, dots / safe_norms, 1.0)
    cos_theta = jnp.clip(cos_theta, -1.0, 1.0)

    return jnp.stack([
        jnp.arccos(cos_theta),
        jnp.sqrt(jnp.sum(diff**2, axis=1))
    ])

def get_query_points_semantic(camera, env_name, frame):
    """Locate task objects in ``frame`` with GroundingDINO and return tracker query points.

    Only ``Stack_D0`` is supported (prompt ``"red block . green block"``). For each
    detected phrase, in sorted phrase order, the highest-logit box is taken. Its first
    two entries (the box centre in GroundingDINO's (cx, cy, w, h) layout) are scaled by
    256 and emitted as a ``[0, y, x]`` row. If the number of distinct phrases found is
    not exactly two, two all-zero rows are returned instead.

    Args:
        camera: unused; kept for signature parity with ``get_query_points``.
        env_name: environment name. Anything other than ``Stack_D0`` yields no points.
        frame: RGB image as a numpy array (converted with ``Image.fromarray``).

    Returns:
        np.ndarray of shape (N, 3): N == 2 for ``Stack_D0``, N == 0 otherwise.
    """
    global grounding_dino_model
    if grounding_dino_model is None:
        # The model is shared with get_semantic_frame_and_box; load it on the same device.
        grounding_dino_model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", "GroundingDINO/weights/groundingdino_swint_ogc.pth").to(device)

    if env_name != 'Stack_D0':
        # Nothing to detect: an empty (0, 3) array still supports column indexing and
        # stacking in callers (None would not).
        return np.empty((0, 3))

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(Image.fromarray(frame), None)
    boxes, logits, phrases = predict(
        model=grounding_dino_model,
        image=image,
        caption="red block . green block",
        box_threshold=BOX_TRESHOLD,
        text_threshold=TEXT_TRESHOLD,
        device=device,
    )

    # Group detection indices by the phrase they matched.
    boxes_dict = {}
    for i, phrase in enumerate(phrases):
        boxes_dict.setdefault(phrase, []).append(i)

    query_points = []
    for phrase in sorted(boxes_dict):
        indices = boxes_dict[phrase]
        # Most confident box for this phrase; only its centre is used.
        x, y, _, _ = boxes[indices[np.argmax(logits[indices])]] * 256
        query_points.append(np.hstack(([0], y, x)))

    if len(query_points) != 2:
        print("Couldn't find all objects! Returning junk data")
        query_points = np.array(
            [
                [0, 0, 0],
                [0, 0, 0],
            ]
        )

    return np.array(query_points)
    
#@profile
def get_semantic_frame_and_box(camera, env_name, frame):
    """Detect task objects with GroundingDINO; return one ``[crop, box, name]`` per object.

    Only ``Stack_D0`` is supported (objects ``"red block"`` and ``"green block"``); for any
    other ``env_name`` the function returns ``None``. For each object, the highest-logit
    detection's box is scaled by 256 and returned as a tensor on ``device`` in
    GroundingDINO's (cx, cy, w, h) layout. The crop is ``crop_and_resize`` applied to the
    box corners, clipped to [0, 255].

    When an object is not detected, or its clipped box has zero width or height, the last
    successful entry for this ``camera`` in ``last_grounding_results`` is reused.

    Raises:
        KeyError: if such a fallback is needed before the object was ever detected for
            this camera.
    """
    global grounding_dino_model, last_grounding_results
    if grounding_dino_model is None:
        grounding_dino_model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", "GroundingDINO/weights/groundingdino_swint_ogc.pth").to(device)

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    frame_and_box = []
    if env_name == 'Stack_D0':
        objects = ["red block", "green block"]
        image, _ = transform(Image.fromarray(frame), None)
        text_prompt = " . ".join(objects)
        boxes, logits, phrases = predict(
            model=grounding_dino_model,
            image=image,
            caption=text_prompt,
            box_threshold=BOX_TRESHOLD,
            text_threshold=TEXT_TRESHOLD,
            device=device
        )
        # Group detection indices by the phrase they matched.
        boxes_dict = {}
        for i, phrase in enumerate(phrases):
            boxes_dict.setdefault(phrase, []).append(i)

        for obj in objects:
            if obj in boxes_dict:
                indices = boxes_dict[obj]

                # Find the most confident bounding box for this phrase
                best_box = boxes[indices[np.argmax(logits[indices])]] * 256
                x, y, width, height = best_box
                # Clip before the uint8 cast. Out-of-range floats would otherwise wrap
                # (a box touching the right/bottom edge reaches 256 -> 0, negatives -> ~255).
                corners = np.clip(np.array([
                    [x - width/2, y - height/2],
                    [x + width/2, y + height/2],
                ], dtype=np.float64), 0, 255).astype(np.uint8)
                if corners[0][0] == corners[1][0] or corners[0][1] == corners[1][1]:
                    # Degenerate (zero-area) crop: reuse the cached result.
                    frame_and_box.append(last_grounding_results[camera][obj])
                else:
                    frame_and_box.append([crop_and_resize(frame, corners), torch.asarray(best_box, device=device), obj])

                    # Cache if next hit fails
                    last_grounding_results.setdefault(camera, {})[obj] = frame_and_box[-1]
            else:
                # No results this frame for this object, return last
                frame_and_box.append(last_grounding_results[camera][obj])

        assert len(frame_and_box) == 2
        return frame_and_box

    return None

def get_query_points(camera, env_name, env):
    """Project known simulator objects into ``camera`` to seed keypoint tracking.

    Each returned row is ``[0, *get_object_pixel_coords(...)]``. Supported cases:

    * ``Stack_D0`` with ``sideview``/``frontview``/``agentview``: ``cubeA_g0``, ``cubeB_g0``.
    * ``Square_D0`` with ``agentview``/``sideview``: ``SquareNut_g0`` .. ``SquareNut_g4``,
      with ``offset=[0, 0, 1.0]`` and ``obj_size_ratio=True``.
    * ``Coffee_D0`` with any camera: ``coffee_pod_g0``, with ``offset=[0, 0, 0.0]`` and
      ``obj_size_ratio=True``.

    Any other combination returns an empty list and prints a warning.
    """
    def project(geom_name, **extra):
        # One query row: leading 0, then the geom's projected pixel coordinates.
        coords = get_object_pixel_coords(env.env.sim, geom_name, camera_name=camera, **extra)
        return np.hstack(([0], coords))

    query_points = []
    if env_name == 'Stack_D0' and camera in ("sideview", "frontview", "agentview"):
        query_points = np.array([project(name) for name in ("cubeA_g0", "cubeB_g0")])
    elif env_name == 'Square_D0' and camera in ("agentview", "sideview"):
        nut_geoms = ("SquareNut_g0", "SquareNut_g1", "SquareNut_g2", "SquareNut_g3", "SquareNut_g4")
        query_points = np.array([
            project(name, offset=np.array([0, 0, 1.0]), obj_size_ratio=True)
            for name in nut_geoms
        ])
    elif env_name == 'Coffee_D0':
        query_points = np.array([
            project("coffee_pod_g0", offset=np.array([0, 0, 0.0]), obj_size_ratio=True),
        ])

    if len(query_points) == 0:
        print("No query points found!")
    return query_points

#@profile
def frame_to_keypoints(env_name, frame, env, is_robosuite=False, is_first_ob=False, proprio_state=[], cam_names=[], semantic=False):
    """Track keypoints online with TAPIR and build a flat observation vector.

    On ``is_first_ob``, query points are collected for each camera and shifted into that
    camera's 256-px-wide slice of ``frame``. They come from ``get_query_points_semantic``
    on the slice when ``semantic`` is set, and from ``get_query_points`` otherwise. The
    points then re-initialise TAPIR's query features and causal state, the track
    history and the output buffers. Every call advances the online tracker by one frame.

    Args:
        env_name: environment name forwarded to the query-point helpers.
        frame: camera views laid side by side along axis 1, one 256-px-wide slice per
            camera (presumably (H, 256 * n_cams, 3) — confirm).
        env: environment forwarded to ``get_query_points``.
        is_robosuite: unused.
        is_first_ob: reset tracking state for a new episode.
        proprio_state: numpy array of proprioceptive values placed at the front.
        cam_names: camera names, in the order of their slices in ``frame``.
        semantic: detect query points with GroundingDINO instead of simulator state.

    Returns:
        torch.Tensor (float64, on ``device``) laid out as
        ``[proprio | track coords (2 per point) | angles | magnitudes]``. Angles and
        magnitudes compare each point with the previous call and are zero on the first
        call after a reset. NOTE: this is the module-level ``ret_tensor``, overwritten on
        every call — clone it to keep a snapshot.
    """
    global tapir, query_features, causal_state, online_model_init, online_model_predict, last_tracks, keypoint_viz, frame_tensor_cpu, ret_tensor, proprio_tensor_cpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if tapir is None:
        init_tapir()

    # Width of each camera's slice within `frame`.
    width = 256

    if is_first_ob:
        all_query_points = np.empty((0, 3))
        for i, camera in enumerate(cam_names):
            if semantic:
                query_points = get_query_points_semantic(camera, env_name, frame[:, (i * width):((i + 1) * width), :])
            else:
                query_points = get_query_points(camera, env_name, env)
            # Shift column 2 (x in the [t, y, x] rows) into this camera's slice.
            query_points[:, 2] += i * width
            all_query_points = np.vstack((all_query_points, query_points))

        if DEBUG:
            # Debug snapshot of the first frame and its query points (was written unconditionally).
            with open("jax_test.pkl", 'wb') as debug_file:
                pickle.dump((frame, all_query_points), debug_file)
        query_features = online_model_init(frame, all_query_points[None])
        causal_state = tapir.construct_initial_causal_state(
            all_query_points.shape[0], len(query_features.resolutions) - 1
        )
        last_tracks = []
        keypoint_viz = []

        # Pinned host staging buffer for proprio, plus the reusable output buffer.
        proprio_tensor_cpu = torch.empty(len(proprio_state), dtype=torch.float64, device='cpu', pin_memory=True)
        ret_tensor = torch.empty(len(proprio_state) + len(all_query_points) * 4, device=device, dtype=torch.float64)

    tracks, visibles, causal_state = online_model_predict(
        frames=frame,
        query_features=query_features,
        causal_context=causal_state,
    )
    keypoint_viz.append({'frame': frame, 'tracks': tracks, 'visibles': visibles})

    tracks_flat = tracks.reshape(-1)
    tracks_r = tracks.reshape(-1, 2)

    if len(last_tracks) > 0:
        results = fast_2d_angles_and_magnitudes_jax(tracks_r, last_tracks.reshape(-1, 2))
        angles, magnitudes = torch.from_numpy(np.array(results))
    else:
        angles = torch.zeros(len(tracks_r))
        magnitudes = torch.zeros(len(tracks_r))

    last_tracks = tracks_r

    # Layout: [proprio | tracks (2 per point) | angles | magnitudes]
    n_prop, n_pts = len(proprio_state), len(tracks_r)
    proprio_tensor_cpu.copy_(torch.from_numpy(proprio_state))
    ret_tensor[:n_prop] = proprio_tensor_cpu.to(device, non_blocking=True)
    ret_tensor[n_prop:n_prop + n_pts * 2] = torch.from_numpy(np.array(tracks_flat))
    ret_tensor[n_prop + n_pts * 2:n_prop + n_pts * 3] = angles
    ret_tensor[n_prop + n_pts * 3:n_prop + n_pts * 4] = magnitudes
    return ret_tensor

def set_seed(seed):
    """Seed every RNG the project uses and switch torch to deterministic kernels.

    Seeds torch (CPU and all CUDA devices), NumPy's global RNG and ``random``. It also
    exports ``PYTHONHASHSEED``, which only affects subprocesses started afterwards
    because the running interpreter's hash seed is fixed at startup. Finally it disables
    cuDNN autotuning and flash scaled-dot-product attention, and requests deterministic
    algorithms (warning rather than raising when an op has none).
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.backends.cuda.enable_flash_sdp(False)
    torch.use_deterministic_algorithms(True, warn_only=True)

def load_expert_data(path):
    """Unpickle and return the object stored at ``path``.

    Security: ``pickle.load`` can execute arbitrary code, so only load trusted files.
    """
    with open(path, 'rb') as handle:
        data = pickle.load(handle)
    return data

def save_expert_data(data, path):
    """Pickle ``data`` to ``path`` unless a file already exists there.

    An existing file is left untouched. Always returns ``None``.
    """
    if os.path.exists(path):
        return None
    with open(path, 'wb') as handle:
        pickle.dump(data, handle)
    return None

def create_matrices(expert_data, use_torch=False):
    """Collect per-trajectory observations/actions and each trajectory's flat start index.

    Args:
        expert_data: iterable of dicts with ``'observations'`` and ``'actions'``.
        use_torch: wrap each trajectory's arrays, and the start indices, as tensors
            (``torch.as_tensor``) instead of leaving them as given / NumPy.

    Returns:
        ``(obs_matrix, act_matrix, traj_starts)``. The first two are lists with one entry
        per trajectory. ``traj_starts[k]`` is the index at which trajectory ``k`` begins
        once all trajectories are concatenated.
    """
    def keep_as_is(value):
        return value

    convert = torch.as_tensor if use_torch else keep_as_is

    obs_matrix, act_matrix, starts = [], [], []
    offset = 0
    for traj in expert_data:
        starts.append(offset)
        offset += len(traj['observations'])
        obs_matrix.append(convert(traj['observations']))
        act_matrix.append(convert(traj['actions']))

    traj_starts = torch.as_tensor(starts) if use_torch else np.asarray(starts)
    return obs_matrix, act_matrix, traj_starts

def compute_accum_distance_torch(nearest_neighbors, max_lookbacks, obs_history, sequence_lengths, flattened_obs_matrix, decay_factors):
    """Decay-weighted sum of step-wise distances between recent history and each candidate's past.

    Let ``L = max_lookbacks.max()``. For every batch element ``b`` and candidate ``m``,
    this compares the last ``L`` rows of ``obs_history[b]`` with the ``L`` rows of
    ``flattened_obs_matrix`` ending at ``nearest_neighbors[b, m]``, step by step. It sums
    the per-step Euclidean distances weighted by ``decay_factors``. Steps beyond the
    candidate's ``max_lookbacks`` or beyond the element's ``sequence_lengths`` are masked out.

    Args:
        nearest_neighbors: (b, m) integer row indices into ``flattened_obs_matrix``.
        max_lookbacks: (b, m) per-candidate lookback limits.
        obs_history: (b, T, D) tensor; the most recent observation is at index T - 1.
        sequence_lengths: (b,) number of valid history steps per batch element.
        flattened_obs_matrix: (N, D) dataset observations.
        decay_factors: 1-D weights. The last ``L`` entries are used, and the final entry
            applies to the most recent step.

    Returns:
        (b, m) tensor of accumulated distances.
    """
    b, m = nearest_neighbors.shape
    device = nearest_neighbors.device
    max_seq_len = max_lookbacks.max().item()

    # Lookback offsets L-1, ..., 1, 0 (oldest first), shaped (1, 1, L) for broadcasting.
    seq_indices = torch.arange(max_seq_len, device=device).flip(0).view(1, 1, -1)
    # Dataset rows nn-(L-1) .. nn, clamped at 0 so indices never go negative.
    # NOTE(review): nothing here stops a lookback at a trajectory start; presumably
    # max_lookbacks is capped by the caller — confirm.
    gather_indices = torch.maximum((nearest_neighbors).unsqueeze(2) - seq_indices, torch.tensor(0))
    matrix_slices = flattened_obs_matrix[gather_indices]

    # History positions T-L .. T-1. If L > T these go negative and wrap to the end of the
    # history; such steps are expected to be masked out via sequence_lengths.
    obs_indices = obs_history.shape[1] - 1 - seq_indices
    obs_expanded = obs_history[torch.arange(b, device=device).view(b, 1, 1).expand(-1, m, max_seq_len), obs_indices]

    # (b, m, L): a step counts only if within both the candidate's lookback and the history length.
    valid_mask = (seq_indices < max_lookbacks.unsqueeze(2)) & (seq_indices < sequence_lengths.view(b, 1, 1))
    feature_mask = valid_mask.unsqueeze(3)
    
    # Masking by multiplication: NaN * 0 is still NaN, hence the explicit clean-up below.
    diff = (obs_expanded - matrix_slices) * feature_mask
    distances = torch.sqrt(torch.sum(diff * diff, dim=3))
    # NOTE(review): this also silently zeroes NaN distances at *valid* steps.
    distances[torch.isnan(distances)] = 0
    # if len(torch.where(torch.isnan(obs_history))[0]) > 0:
    #     print(gather_indices[0])
    #     print(obs_indices[0])
    #     print(distances[0])
    #     print(valid_mask[0])
    #     reakpoint()
    # NOTE(review): when L == 0, decay_factors[-0:] is the whole vector and may not broadcast.
    weighted_distances = distances * decay_factors[-max_seq_len:].view(1, 1, -1) * valid_mask

    return torch.sum(weighted_distances, dim=2)

def compute_distance_with_rot_torch(curr_ob: torch.Tensor, flattened_obs_matrix: torch.Tensor, rot_weights: torch.Tensor):
    """Weighted Euclidean distance between ``curr_ob`` and every row, treating values as angles.

    Each per-dimension absolute difference is wrapped to the shorter way around the
    circle, ``min(d, 2*pi - d)``, and expressed as a fraction of a full turn. It is then
    scaled by ``rot_weights`` before the row-wise L2 norm is taken over dim 1.

    Returns:
        Tensor with one distance per row of ``flattened_obs_matrix``.
    """
    full_turn = 2 * torch.pi
    abs_delta = (curr_ob - flattened_obs_matrix).abs()
    turn_fraction = torch.minimum(abs_delta, full_turn - abs_delta) / full_turn
    weighted = turn_fraction * rot_weights
    return weighted.pow(2).sum(dim=1).sqrt()

def compute_distance_torch(curr_ob: torch.Tensor, flattened_obs_matrix: torch.Tensor):
    """Pairwise Euclidean distances between rows of ``curr_ob`` and ``flattened_obs_matrix``.

    For ``curr_ob`` of shape (Q, D) and ``flattened_obs_matrix`` of shape (N, D), returns
    a (Q, N) tensor whose entry [i, j] is ``||flattened_obs_matrix[j] - curr_ob[i]||``.
    """
    pairwise_delta = flattened_obs_matrix.unsqueeze(0) - curr_ob.unsqueeze(1)
    return pairwise_delta.pow(2).sum(dim=2).sqrt()
