import json
import logging
import os
import pickle

from onpolicy.debug import debug_print
from torch import autograd
from einops import rearrange, reduce, repeat
import torch
import torch.nn as nn
import torch.distributed as dist
import numpy as np
import h5py

import psutil
import os
import time
# import jax
# import jax.numpy as jnp
# import flashbax as fbx
# from flashbax.vault import Vault

class D4RL_Dataset:
    """Fixed-length windows over a D4RL dataset obtained through ``gym``.

    ``obs`` / ``actions`` hold the whole dataset as float32 CUDA tensors
    (min/max normalized when ``normalization_path`` is given). ``__getitem__``
    serves only windows that contain no terminal or timeout step, while
    ``sample`` draws windows from the raw (unnormalized) numpy arrays and zeroes
    every step that follows a terminal/timeout inside the window.
    """

    def __init__(self, scenario_name, window_size, normalization_path, dataset_path, **kwargs):
        """Load ``scenario_name``'s dataset and index every valid window.

        Args:
            scenario_name: gym environment id, created with ``gym.make``.
            window_size: number of consecutive steps per window.
            normalization_path: ``.npz`` file with ``obs_min``, ``obs_max``,
                ``action_min`` and ``action_max`` arrays, or ``None`` to skip
                normalization.
            dataset_path: not used; the data comes from ``env.get_dataset()``.
            **kwargs: ignored.
        """
        # Otherwise-unused imports, presumably kept for their side effects
        # (d4rl environment registration) — confirm before removing.
        import gym
        import d4rl
        import mujoco_py
        import d4rl.gym_mujoco

        device = torch.device('cuda')

        def to_device(array):
            return torch.from_numpy(array).to(dtype=torch.float32, device=device)

        self._env = gym.make(scenario_name)
        self._dataset = self._env.get_dataset()  # h5path=dataset_path is not passed
        self._size = self._dataset["observations"].shape[0]
        self._window_size = window_size
        self.obs = to_device(self._dataset["observations"])
        self.actions = to_device(self._dataset["actions"])

        self.normalize = normalization_path is not None
        if self.normalize:
            # The context manager closes the file handle held by the NpzFile.
            with np.load(normalization_path) as normalization:
                self.obs_min = to_device(normalization["obs_min"])
                self.obs_max = to_device(normalization["obs_max"])
                self.action_min = to_device(normalization["action_min"])
                self.action_max = to_device(normalization["action_max"])

        # A window is usable only if none of its steps is a terminal or a
        # timeout. "timeouts" is optional, mirroring the check in ``sample``.
        done = self._dataset["terminals"] == 1
        if "timeouts" in self._dataset:
            done = done | (self._dataset["timeouts"] == 1)
        self.valid_idx = self._valid_window_starts(done, self._window_size)

        self.obs = self.normalize_obs(self.obs)
        self.actions = self.normalize_action(self.actions)
        debug_print(self.obs.shape, self.actions.shape)

    @staticmethod
    def _valid_window_starts(done, window_size):
        """Return every start ``i`` such that ``done[i:i + window_size]`` is all False.

        Only windows that fit entirely inside ``done`` are considered, including
        the last one starting at ``len(done) - window_size``. A prefix sum keeps
        the scan O(len(done)) instead of O(len(done) * window_size).
        """
        done = np.asarray(done, dtype=bool)
        num_starts = done.shape[0] - window_size + 1
        if num_starts <= 0:
            return []
        prefix = np.concatenate(([0], np.cumsum(done, dtype=np.int64)))
        # prefix[i + w] - prefix[i] counts the done steps in done[i:i + w].
        hits = prefix[window_size:window_size + num_starts] - prefix[:num_starts]
        return np.flatnonzero(hits == 0).tolist()

    def __len__(self):
        """Number of windows free of terminals/timeouts."""
        return len(self.valid_idx)

    def normalize_obs(self, obs):
        """Map ``[obs_min, obs_max]`` to about ``[-1, 1]``; identity if normalization is off."""
        if not self.normalize:
            return obs
        return 2 * ((obs - self.obs_min) / (self.obs_max - self.obs_min + 1e-6) - 0.5)

    def normalize_action(self, action):
        """Map ``[action_min, action_max]`` to about ``[-1, 1]``; identity if normalization is off."""
        if not self.normalize:
            return action
        return 2 * ((action - self.action_min) / (self.action_max - self.action_min + 1e-6) - 0.5)

    def __getitem__(self, index):
        """Return ``(obs, actions)`` slices of length ``window_size`` for valid window ``index``."""
        idx = self.valid_idx[index]
        return self.obs[idx:idx+self._window_size], self.actions[idx:idx+self._window_size]

    def sample(self, batch_size):
        """Draw ``batch_size`` random windows from the raw numpy dataset.

        Returns ``[observations, actions]`` as float32 CPU tensors stacked along
        a new leading window axis (``w, bs, d`` for 2-D data). Within a window,
        every step after a terminal (or a timeout, when that key exists) is zeroed.
        """
        num_steps = self._dataset["observations"].shape[0]
        # randint's upper bound is exclusive; the last full window starts at
        # num_steps - window_size, hence the +1.
        indices = np.random.randint(0, num_steps - self._window_size + 1, (batch_size,))
        observations = []
        actions = []
        masks = np.ones((batch_size,))
        for w in range(self._window_size):
            step = indices + w
            observations.append(self._dataset["observations"][step] * masks[:, np.newaxis])
            actions.append(self._dataset["actions"][step] * masks[:, np.newaxis])
            timeouts = (1 - self._dataset["timeouts"][step]) if "timeouts" in self._dataset else 1
            masks = masks * timeouts * (1 - self._dataset["terminals"][step])
        observations = torch.from_numpy(np.stack(observations, axis=0)).to(dtype=torch.float32)  # w, bs, d
        actions = torch.from_numpy(np.stack(actions, axis=0)).to(dtype=torch.float32)  # w, bs, d
        return [observations, actions]

class Robomimic_Dataset:
    """Low-dimensional robomimic demos served as fixed-length chunks.

    Each step's observation is the concatenation, in this order, of
    ``robot0_eef_pos``, ``robot0_eef_quat``, ``robot0_gripper_qpos``, the
    ``robot1_*`` equivalents when the path contains "transport", and
    ``object``. Episodes are shuffled and split 90/10 into train/val, then
    min/max normalized with precomputed statistics. Every item is an
    ``act_step``-long chunk that lies inside a single episode.
    """

    def __init__(self, scenario_name, window_size, dataset_path, act_step, split,
                 normalization_path, split_seed=1, **kwargs):
        """Load demos from ``dataset_path`` and keep the episodes of ``split``.

        Args:
            scenario_name: unused; accepted for interface compatibility.
            window_size: stored as ``_window_size``; items use ``act_step``.
            dataset_path: robomimic HDF5 file with ``data/demo_<k>`` groups.
            act_step: number of consecutive steps returned per item.
            split: ``"train"`` (first 90% of shuffled episodes) or ``"val"``.
            normalization_path: ``.npz`` file with ``obs_min``, ``obs_max``,
                ``action_min`` and ``action_max`` arrays.
            split_seed: seed for the episode shuffle. With a fixed seed, train
                and val instances partition the same permutation and never
                overlap. ``None`` draws from NumPy's global RNG (the old behavior).
            **kwargs: ignored.

        Raises:
            ValueError: if ``split`` is unknown or selects no episodes.
        """
        self._window_size = window_size
        self.act_step = act_step

        # The context manager closes the file handle held by the NpzFile.
        with np.load(normalization_path) as normalization:
            self.obs_min = normalization["obs_min"]
            self.obs_max = normalization["obs_max"]
            self.action_min = normalization["action_min"]
            self.action_max = normalization["action_max"]

        episode_obs, episode_actions = self._load_demos(dataset_path)
        selected = self._select_episodes(len(episode_actions), split, split_seed)
        if len(selected) == 0:
            raise ValueError(
                f"split {split!r} selects no episodes out of {len(episode_actions)}"
            )

        self.index = self._window_starts(
            [episode_actions[ep].shape[0] for ep in selected], act_step
        )
        self.actions = np.concatenate([episode_actions[ep] for ep in selected], axis=0)
        self.observations = np.concatenate([episode_obs[ep] for ep in selected], axis=0)

        debug_print(self.observations.shape, self.obs_min.shape)
        # Min/max scaling of [min, max] to about [-1, 1]; 1e-6 avoids division
        # by zero on constant dimensions.
        self.actions = (self.actions - self.action_min) / (self.action_max - self.action_min + 1e-6) * 2 - 1
        self.observations = (self.observations - self.obs_min) / (self.obs_max - self.obs_min + 1e-6) * 2 - 1
        self.lower_action = self.actions.min()
        self.upper_action = self.actions.max()
        self.lower_obs = self.observations.min()
        self.upper_obs = self.observations.max()

        debug_print(self.actions.min(), self.actions.max(), self.observations.min(), self.observations.max())
        debug_print(self.lower_action, self.upper_action, self.lower_obs, self.upper_obs)

    def _load_demos(self, load_path):
        """Read every demo's concatenated low-dim observations and its actions.

        Returns two lists (observations, actions) holding one array per demo,
        ordered by the integer suffix of the ``demo_<k>`` group names.
        """
        low_dim_obs_names = [
            "robot0_eef_pos",
            "robot0_eef_quat",
            "robot0_gripper_qpos",
        ]
        if "transport" in load_path:
            low_dim_obs_names += [
                "robot1_eef_pos",
                "robot1_eef_quat",
                "robot1_gripper_qpos",
            ]
        low_dim_obs_names.append("object")

        episode_obs = []
        episode_actions = []
        with h5py.File(load_path, "r") as f:
            # Kept for backward compatibility only: this handle is closed once
            # the ``with`` block exits, so it must not be read afterwards.
            self._dataset = f
            demos = sorted(f["data"].keys(), key=lambda name: int(name[5:]))
            first = demos[0]
            debug_print(f[f"data/{first}/obs"].keys())

            obs_dim = 0
            for low_dim_obs_name in low_dim_obs_names:
                dim = f[f"data/{first}/obs/{low_dim_obs_name}"].shape[1]
                obs_dim += dim
                logging.info(f"Using {low_dim_obs_name} with dim {dim} for observation")
            action_dim = f[f"data/{first}/actions"].shape[1]
            logging.info(f"Total low-dim observation dim: {obs_dim}")
            logging.info(f"Action dim: {action_dim}")

            for ep in demos:
                episode_obs.append(
                    np.hstack([f[f"data/{ep}/obs/{name}"][()] for name in low_dim_obs_names])
                )
                episode_actions.append(f[f"data/{ep}/actions"][()])
        return episode_obs, episode_actions

    @staticmethod
    def _select_episodes(num_episodes, split, split_seed=1, train_frac=0.9):
        """Return the shuffled episode indices belonging to ``split``.

        The first ``int(num_episodes * train_frac)`` shuffled indices form
        ``"train"`` and the rest form ``"val"``. A local ``RandomState`` is used
        for a non-``None`` seed so the global RNG is left untouched.

        Raises:
            ValueError: if ``split`` is neither ``"train"`` nor ``"val"``.
        """
        if split_seed is None:
            order = np.random.permutation(num_episodes)
        else:
            order = np.random.RandomState(split_seed).permutation(num_episodes)
        cut = int(num_episodes * train_frac)
        if split == "train":
            return order[:cut]
        if split == "val":
            return order[cut:]
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    @staticmethod
    def _window_starts(lengths, act_step):
        """Return global start indices of every ``act_step`` chunk inside an episode.

        Episodes of the given ``lengths`` are laid out back to back. A chunk
        ``[s, s + act_step)`` is kept only if it does not cross an episode end,
        and the last chunk of each episode is included.
        """
        starts = []
        offset = 0
        for length in lengths:
            starts.extend(range(offset, offset + length - act_step + 1))
            offset += length
        return starts

    def __len__(self):
        """Number of in-episode chunks available."""
        return len(self.index)

    def __getitem__(self, index):
        """Return ``(observations, actions)`` arrays of ``act_step`` consecutive steps."""
        start = self.index[index]
        return self.observations[start:start+self.act_step], self.actions[start:start+self.act_step]
    
    
class Robomimic_img_Dataset:
    """Robomimic image demos plus low-dim state, served as fixed-length chunks.

    Per step, the selected camera images are flattened and the min/max
    normalized low-dim state (every ``low_dim_keys`` entry except ``object``)
    is appended. Episodes are shuffled and split 90/10 into train/val. Every
    item is an ``act_step``-long chunk inside a single episode, stored on CUDA.
    """

    def __init__(self, scenario_name, window_size, dataset_path, act_step, split,
                 image_names=None, low_dim_keys=None,
                 normalization_path='/workspace/Diffusion-PPO/normalization.npz',
                 split_seed=1, **kwargs):
        """Load image demos from ``dataset_path`` and keep the episodes of ``split``.

        Args:
            scenario_name: unused; accepted for interface compatibility.
            window_size: stored as ``_window_size``; items use ``act_step``.
            dataset_path: robomimic HDF5 file with ``data/demo_<k>`` groups.
            act_step: number of consecutive steps returned per item.
            split: ``"train"`` (first 90% of shuffled episodes) or ``"val"``.
            image_names: observation keys of the camera images; defaults to
                ``["agentview_image"]``.
            low_dim_keys: low-dim observation keys to append (``"object"`` is
                skipped). Required.
            normalization_path: path whose ``/normalization.npz`` suffix is
                rewritten to ``-img/normalization.npz`` before loading.
            split_seed: seed for the episode shuffle, so that train and val
                instances never overlap. ``None`` uses NumPy's global RNG.
            **kwargs: ignored.

        Raises:
            ValueError: if ``low_dim_keys`` is missing, or ``split`` is unknown
                or selects no episodes.
        """
        if image_names is None:
            image_names = ["agentview_image"]
        if low_dim_keys is None:
            raise ValueError("low_dim_keys must list the low-dimensional observation keys to load")
        debug_print("Image names", image_names)
        self._window_size = window_size
        self.act_step = act_step

        # Image runs read statistics from the sibling "<dir>-img" directory; the
        # context manager closes the NpzFile handle.
        with np.load(normalization_path.replace("/normalization.npz", "-img/normalization.npz")) as normalization:
            self.obs_min = normalization["obs_min"]
            self.obs_max = normalization["obs_max"]
            self.action_min = normalization["action_min"]
            self.action_max = normalization["action_max"]

        episode_images, episode_low_dim, episode_actions = self._load_demos(
            dataset_path, image_names, low_dim_keys
        )
        selected = self._select_episodes(len(episode_actions), split, split_seed)
        if len(selected) == 0:
            raise ValueError(
                f"split {split!r} selects no episodes out of {len(episode_actions)}"
            )

        self.index = self._window_starts(
            [episode_actions[ep].shape[0] for ep in selected], act_step
        )
        self.actions = np.concatenate([episode_actions[ep] for ep in selected], axis=0)
        self.observations = np.concatenate([episode_images[ep] for ep in selected], axis=0)
        self.low_dim_obs = np.concatenate([episode_low_dim[ep] for ep in selected], axis=0)

        debug_print(self.observations.shape, self.obs_min.shape)
        # Move the last axis to position 1 (presumably HWC -> CHW images).
        self.observations = self.observations.transpose(0, 3, 1, 2)
        debug_print('fa', self.observations.shape, self.observations.dtype)
        debug_print(self.obs_max, self.obs_min, self.action_max, self.action_min)
        # Flatten each frame so the low-dim features can be appended per step.
        self.observations = self.observations.reshape(self.observations.shape[0], -1)
        device = torch.device('cuda')
        self.actions = (self.actions - self.action_min) / (self.action_max - self.action_min + 1e-6) * 2 - 1
        self.observations = torch.tensor(self.observations, device=device, dtype=torch.float32)
        # NOTE(review): actions keep their NumPy dtype (possibly float64) while
        # observations are float32 — confirm the consumer expects this.
        self.actions = torch.tensor(self.actions, device=device)
        debug_print(self.low_dim_obs.shape, self.observations.shape)
        # Only the low-dim part is min/max normalized; pixel values are kept as stored.
        self.low_dim_obs = (self.low_dim_obs - self.obs_min) / (self.obs_max - self.obs_min + 1e-6) * 2 - 1
        self.low_dim_obs = torch.tensor(self.low_dim_obs, device=device, dtype=torch.float32)
        self.observations = torch.cat([self.observations, self.low_dim_obs], dim=-1)

    def _load_demos(self, load_path, image_names, low_dim_keys):
        """Read per-demo images, low-dim observations and actions.

        Returns three lists with one array per demo, ordered by the integer
        suffix of the ``demo_<k>`` group names. They hold the ``image_names``
        arrays joined with ``np.hstack``, the non-``"object"`` ``low_dim_keys``
        arrays joined the same way, and the actions.
        """
        low_dim_names = [key for key in low_dim_keys if key != "object"]
        episode_images = []
        episode_low_dim = []
        episode_actions = []
        with h5py.File(load_path, "r") as f:
            # Kept for backward compatibility only: this handle is closed once
            # the ``with`` block exits, so it must not be read afterwards.
            self._dataset = f
            demos = sorted(f["data"].keys(), key=lambda name: int(name[5:]))
            first = demos[0]
            debug_print(f[f"data/{first}/obs"].keys())
            action_dim = f[f"data/{first}/actions"].shape[1]
            logging.info(f"Action dim: {action_dim}")

            for ep in demos:
                episode_images.append(
                    np.hstack([f[f"data/{ep}/obs/{name}"][()] for name in image_names])
                )
                episode_low_dim.append(
                    np.hstack([f[f"data/{ep}/obs/{key}"][()] for key in low_dim_names])
                )
                episode_actions.append(f[f"data/{ep}/actions"][()])
        return episode_images, episode_low_dim, episode_actions

    @staticmethod
    def _select_episodes(num_episodes, split, split_seed=1, train_frac=0.9):
        """Return the shuffled episode indices belonging to ``split``.

        The first ``int(num_episodes * train_frac)`` shuffled indices form
        ``"train"`` and the rest form ``"val"``. A local ``RandomState`` is used
        for a non-``None`` seed so the global RNG is left untouched.

        Raises:
            ValueError: if ``split`` is neither ``"train"`` nor ``"val"``.
        """
        if split_seed is None:
            order = np.random.permutation(num_episodes)
        else:
            order = np.random.RandomState(split_seed).permutation(num_episodes)
        cut = int(num_episodes * train_frac)
        if split == "train":
            return order[:cut]
        if split == "val":
            return order[cut:]
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    @staticmethod
    def _window_starts(lengths, act_step):
        """Return global start indices of every ``act_step`` chunk inside an episode.

        Episodes of the given ``lengths`` are laid out back to back. A chunk
        ``[s, s + act_step)`` is kept only if it does not cross an episode end,
        and the last chunk of each episode is included.
        """
        starts = []
        offset = 0
        for length in lengths:
            starts.extend(range(offset, offset + length - act_step + 1))
            offset += length
        return starts

    def __len__(self):
        """Number of in-episode chunks available."""
        return len(self.index)

    def __getitem__(self, index):
        """Return ``(observations, actions)`` tensors of ``act_step`` consecutive steps."""
        start = self.index[index]
        return self.observations[start:start+self.act_step], self.actions[start:start+self.act_step]
    

class Robomimic_Dataset_mh:
    """Robot proprioceptive observations from all robomimic demos, sampled as windows.

    Observations concatenate ``robot0_eef_pos``, ``robot0_eef_quat`` and
    ``robot0_gripper_qpos``, plus the ``robot1_*`` equivalents when the path
    contains "transport". No split or normalization is applied. All demos are
    concatenated, and the last step of each demo is flagged so that ``sample``
    never mixes steps from two different demos.
    """

    def __init__(self, scenario_name, window_size, dataset_path, **kwargs):
        """Load every demo in ``dataset_path``.

        Args:
            scenario_name: unused; accepted for interface compatibility.
            window_size: number of consecutive steps drawn by ``sample``.
            dataset_path: robomimic HDF5 file with ``data/demo_<k>`` groups.
            **kwargs: ignored.
        """
        load_path = dataset_path
        self._window_size = window_size

        low_dim_obs_names = [
            "robot0_eef_pos",
            "robot0_eef_quat",
            "robot0_gripper_qpos",
        ]
        if "transport" in load_path:
            low_dim_obs_names += [
                "robot1_eef_pos",
                "robot1_eef_quat",
                "robot1_gripper_qpos",
            ]

        episode_obs = []
        episode_actions = []
        episode_dones = []
        with h5py.File(load_path, "r") as f:
            # Kept for backward compatibility only: this handle is closed once
            # the ``with`` block exits, so it must not be read afterwards.
            self._dataset = f
            demos = sorted(f["data"].keys(), key=lambda name: int(name[5:]))
            first = demos[0]
            logging.debug("Keys of %s: %s", first, list(f[f"data/{first}"].keys()))

            obs_dim = 0
            for low_dim_obs_name in low_dim_obs_names:
                dim = f[f"data/{first}/obs/{low_dim_obs_name}"].shape[1]
                obs_dim += dim
                logging.info(f"Using {low_dim_obs_name} with dim {dim} for observation")
            action_dim = f[f"data/{first}/actions"].shape[1]
            logging.info(f"Total low-dim observation dim: {obs_dim}")
            logging.info(f"Action dim: {action_dim}")

            for ep in demos:
                obs = np.hstack([f[f"data/{ep}/obs/{name}"][()] for name in low_dim_obs_names])
                actions = f[f"data/{ep}/actions"][()]
                # Flag the demo's final step; sample() zeroes everything after it.
                done = np.zeros(actions.shape[0])
                if done.size:
                    done[-1] = 1
                episode_obs.append(obs)
                episode_actions.append(actions)
                episode_dones.append(done)
                logging.debug("%s: obs %s, actions %s", ep, obs.shape, actions.shape)

        self.actions = np.concatenate(episode_actions, axis=0)
        self.observations = np.concatenate(episode_obs, axis=0)
        self._dones = np.concatenate(episode_dones, axis=0)

    def sample(self, batch_size):
        """Draw ``batch_size`` random windows of ``window_size`` steps.

        Returns ``[observations, actions]`` as float32 CPU tensors stacked along
        a new leading window axis (``w, bs, d``). Steps after a demo's last step
        are zeroed, so a window never continues into the next demo.
        """
        num_steps = self.observations.shape[0]
        # randint's upper bound is exclusive; the last full window starts at
        # num_steps - window_size, hence the +1.
        indices = np.random.randint(0, num_steps - self._window_size + 1, (batch_size,))
        observations = []
        actions = []
        masks = np.ones((batch_size,))
        for w in range(self._window_size):
            step = indices + w
            observations.append(self.observations[step] * masks[:, np.newaxis])
            actions.append(self.actions[step] * masks[:, np.newaxis])
            masks = masks * (1 - self._dones[step])
        observations = torch.from_numpy(np.stack(observations, axis=0)).to(dtype=torch.float32)  # w, bs, d
        actions = torch.from_numpy(np.stack(actions, axis=0)).to(dtype=torch.float32)  # w, bs, d
        return [observations, actions]

class Football_Dataset:
    """Transitions read from a flashbax vault (default vault name "gfootball").

    The vault's first two axes are merged into one, so each item is a single
    ``(obs, action, reward, done)`` transition. All four are stored as CUDA
    tensors; ``obs_mean`` / ``obs_std`` are per-feature NumPy statistics of
    the observations.
    """

    def __init__(self, vault_uid, rel_dir="/root/vaults", vault_name="gfootball"):
        """Read the whole vault ``vault_uid`` into device memory.

        Args:
            vault_uid: identifier of the vault run to read.
            rel_dir: directory containing the vaults.
            vault_name: name of the vault inside ``rel_dir``.
        """
        import jax
        from flashbax.vault import Vault

        vlt = Vault(rel_dir=rel_dir, vault_name=vault_name, vault_uid=vault_uid)
        experience = vlt.read().experience
        # jax.tree_map is deprecated; jax.tree_util.tree_map is the supported API.
        print(jax.tree_util.tree_map(lambda x: x.shape, experience))

        self.obs = self._merge_leading_axes(np.array(experience['obs']))
        self.actions = self._merge_leading_axes(np.array(experience['actions']))
        self.reward = self._merge_leading_axes(np.array(experience['reward']))
        self.done = self._merge_leading_axes(np.array(experience['done']))

        flat_obs = self.obs.reshape(-1, self.obs.shape[-1])
        self.obs_mean = flat_obs.mean(axis=0)
        self.obs_std = flat_obs.std(axis=0)
        debug_print(self.obs.shape)

        device = torch.device('cuda')
        self.actions = torch.tensor(self.actions, device=device)
        self.obs = torch.tensor(self.obs, device=device)
        self.reward = torch.tensor(self.reward, device=device)
        self.done = torch.tensor(self.done, device=device)

    @staticmethod
    def _merge_leading_axes(x):
        """Reshape an array of shape ``(a, b, *rest)`` to ``(a * b, *rest)``."""
        return x.reshape(-1, *x.shape[2:])

    def __len__(self):
        """Number of transitions."""
        return self.obs.shape[0]

    def __getitem__(self, index):
        """Return the ``(obs, action, reward, done)`` transition at ``index``."""
        return self.obs[index], self.actions[index], self.reward[index], self.done[index]