import numpy as np
import torch
import collections
import sys
import glob
import os
import re
import matplotlib.pyplot as plt
import pandas as pd

import numpy as np

def qlearning_dataset_c4(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Similar to the original qlearning_dataset, but returns richer fields and trajectory statistics.

    Args:
        env: Object providing ``get_dataset(**kwargs)`` and, when the dataset has
            no 'timeouts' field, ``_max_episode_steps``.
        dataset: Optional pre-loaded dataset mapping; defaults to env.get_dataset().
        terminate_on_end (bool): If False (default), transitions flagged as the
            final timestep are not written (the trajectory is still closed there).
        **kwargs: Forwarded to env.get_dataset().

    Returns:
        traj_length: [len(traj_0), len(traj_1), ...]  Number of written transitions per trajectory
            (if terminate_on_end=False, the skipped final step is not counted). A trailing
            trajectory that runs to the end of the data without a terminal/timeout flag is
            included, so sum(traj_length) == number of written transitions.
        episode_end_list: [i_0, i_1, ...]             Original dataset index i at which each trajectory
            ended; for a trailing unterminated trajectory, the index of its last written transition.
        data: {
            "start": np.array(start_),                # Whether this is the first transition of a trajectory
            "observations": np.array(obs_),
            "actions": np.array(action_),
            "next_observations": np.array(next_obs_),
            "next_actions": np.array(next_act_),      # Next action a_{t+1}
            "rewards": np.array(reward_),
            "terminals": np.array(done_),
            "timeouts": np.array(timeout_),           # Whether the current transition ended due to timeout
            "trajectory": np.array(traj_),            # Trajectory number (starting from 0)
            "step": np.array(step_),                  # Index i in the original dataset (aligned with obs[i])
            # Optional:
            # "qvel": np.array(qvel_)                 # If dataset contains 'qvel', aligned with current step
            # "qpos": np.array(qpos_)                 # If dataset contains 'qpos', aligned with current step
        }
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    use_timeouts = 'timeouts' in dataset

    # Output buffers
    start_ = []
    obs_ = []
    next_obs_ = []
    action_ = []
    next_act_ = []
    reward_ = []
    done_ = []
    timeout_ = []
    traj_ = []
    step_ = []

    qvel_ = []
    qpos_ = []
    has_qvel = 'qvel' in dataset
    has_qpos = 'qpos' in dataset

    # Trajectory statistics
    traj_length = []          # Final number of transitions per trajectory
    episode_end_list = []     # Original termination index i for each trajectory

    episode_step = 0          # Step count within current trajectory
    curr_traj_len = 0         # Number of transitions collected in current trajectory
    traj_idx = -1             # Incremented to 0 when the first trajectory starts

    # Stop at N-2 because row i+1 supplies the next observation/action.
    for i in range(N - 1):
        obs = dataset['observations'][i].astype(np.float32)
        new_obs = dataset['observations'][i + 1].astype(np.float32)
        act = dataset['actions'][i].astype(np.float32)
        next_act = dataset['actions'][i + 1].astype(np.float32)
        rew = dataset['rewards'][i].astype(np.float32)

        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = bool(dataset['timeouts'][i])
        else:
            # Fallback: infer the episode end from env._max_episode_steps.
            final_timestep = (episode_step == env._max_episode_steps - 1)

        # The first transition of a trajectory opens a new trajectory index.
        start_flag = episode_step == 0
        if start_flag:
            traj_idx += 1

        # With terminate_on_end=False the final step is not written, but the
        # trajectory still ends here and its statistics are recorded.
        if (not terminate_on_end) and final_timestep:
            episode_end_list.append(i)
            traj_length.append(curr_traj_len)  # Number of entries already written
            episode_step = 0
            curr_traj_len = 0
            continue

        start_.append(start_flag)
        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(act)
        next_act_.append(next_act)
        reward_.append(rew)
        done_.append(done_bool)
        timeout_.append(final_timestep)
        traj_.append(traj_idx)
        step_.append(i)

        if has_qvel:
            qvel_.append(np.array(dataset['qvel'][i], dtype=np.float32))
        if has_qpos:
            qpos_.append(np.array(dataset['qpos'][i], dtype=np.float32))

        episode_step += 1
        curr_traj_len += 1

        # If the trajectory ends at this step (terminal or timeout), finalize and reset.
        if done_bool or final_timestep:
            episode_end_list.append(i)
            traj_length.append(curr_traj_len)
            episode_step = 0
            curr_traj_len = 0

    # Close a trailing trajectory that reached the end of the data without a
    # terminal/timeout flag. Its transitions were written with their own
    # trajectory id, so the statistics must account for them too.
    if curr_traj_len > 0:
        episode_end_list.append(step_[-1])
        traj_length.append(curr_traj_len)

    data = {
        "start": np.array(start_, dtype=bool),
        "observations": np.array(obs_, dtype=np.float32),
        "actions": np.array(action_, dtype=np.float32),
        "next_observations": np.array(next_obs_, dtype=np.float32),
        "next_actions": np.array(next_act_, dtype=np.float32),
        "rewards": np.array(reward_, dtype=np.float32),
        "terminals": np.array(done_, dtype=bool),
        "timeouts": np.array(timeout_, dtype=bool),
        "trajectory": np.array(traj_, dtype=np.int64),
        "step": np.array(step_, dtype=np.int64),
    }
    if has_qvel:
        data["qvel"] = np.array(qvel_, dtype=np.float32)
    if has_qpos:
        data["qpos"] = np.array(qpos_, dtype=np.float32)

    return traj_length, episode_end_list, data


def qlearning_dataset_all(env, task=None, dataset=None, terminate_on_end=False, num_traj=0, **kwargs):
    """
    Export every transition, keeping terminal and timeout steps, together with
    per-sample trajectory bookkeeping.

    Args:
        env: Object providing ``get_dataset(**kwargs)``; only used when
            ``dataset`` is None.
        task: Unused; accepted for signature compatibility with the other loaders.
        dataset: Optional pre-loaded dataset mapping.
        terminate_on_end: Unused; terminal/timeout steps are always exported.
        num_traj: If > 0, stop once this many trajectories have been closed.
        **kwargs: Forwarded to ``env.get_dataset``.

    Returns:
        traj_length: Number of exported samples in each trajectory.
        episode_end_list: Running total of exported samples at the end of each
            trajectory.
        data: Dict of stacked arrays ("start", "observations", "actions",
            "next_observations", "next_actions", "rewards", "terminals",
            "timeouts", "trajectory", "step", plus "qvel"/"qpos" when the
            dataset carries "infos/qvel"/"infos/qpos").
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    observations = dataset["observations"]
    actions = dataset["actions"]
    explicit_next_obs = "next_observations" in dataset
    explicit_next_act = "next_actions" in dataset
    has_timeouts = "timeouts" in dataset
    qvel_src = dataset["infos/qvel"] if "infos/qvel" in dataset else None
    qpos_src = dataset["infos/qpos"] if "infos/qpos" in dataset else None

    # If any "next_*" field has to be read from row idx + 1, the final row has
    # no successor and is left out.
    total = dataset["rewards"].shape[0]
    if not (explicit_next_obs and explicit_next_act):
        total -= 1

    columns = {
        "start": [],
        "observations": [],
        "actions": [],
        "next_observations": [],
        "next_actions": [],
        "rewards": [],
        "terminals": [],
        "timeouts": [],
        "trajectory": [],
        "step": [],
    }
    qvel_rows = []
    qpos_rows = []

    lengths = []
    boundaries = []
    in_episode = 0   # samples exported for the currently open trajectory
    exported = 0     # samples exported overall
    closed = 0       # trajectories closed so far

    for idx in range(total):
        terminal = bool(dataset["terminals"][idx])
        # Without explicit timeouts, a terminal flag doubles as the episode end.
        timeout = bool(dataset["timeouts"][idx]) if has_timeouts else terminal

        if explicit_next_obs:
            next_obs = dataset["next_observations"][idx]
        else:
            next_obs = observations[idx + 1]
        if explicit_next_act:
            next_act = dataset["next_actions"][idx]
        else:
            next_act = actions[idx + 1]

        columns["start"].append(in_episode == 0)
        columns["observations"].append(observations[idx].astype(np.float32))
        columns["actions"].append(actions[idx].astype(np.float32))
        columns["next_observations"].append(next_obs.astype(np.float32))
        columns["next_actions"].append(next_act.astype(np.float32))
        columns["rewards"].append(dataset["rewards"][idx].astype(np.float32))
        columns["terminals"].append(terminal)
        columns["timeouts"].append(timeout)
        columns["trajectory"].append(closed)
        columns["step"].append(idx)
        if qvel_src is not None:
            qvel_rows.append(qvel_src[idx].astype(np.float32))
        if qpos_src is not None:
            qpos_rows.append(qpos_src[idx].astype(np.float32))

        in_episode += 1
        exported += 1

        if not (terminal or timeout):
            continue

        lengths.append(in_episode)
        boundaries.append(exported)
        in_episode = 0
        closed += 1
        if 0 < num_traj <= closed:
            break

    # A trajectory still open at the end of the data is reported as well.
    if in_episode:
        lengths.append(in_episode)
        boundaries.append(exported)

    data = {name: np.array(rows) for name, rows in columns.items()}
    if qvel_rows:
        data["qvel"] = np.array(qvel_rows)
    if qpos_rows:
        data["qpos"] = np.array(qpos_rows)

    return lengths, boundaries, data




def qlearning_dataset(env, task=None, dataset=None, terminate_on_end=False, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag.

    Args:
        env: An OfflineEnv object.
        task: Optional task name. If it contains "antmaze" (case-insensitive),
            final-timestep and terminal transitions are always kept.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.

    Note:
        When the dataset has no 'next_observations' field (and the task is not
        antmaze), terminal/final transitions are dropped, because
        observations[i + 1] would belong to the next episode.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset
    # Robust to task being None.
    is_antmaze = bool(task) and 'antmaze' in str(task).lower()
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset

    N = dataset['rewards'].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []

    episode_step = 0
    for i in range(N - 1):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)

        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            if not is_antmaze:
                continue
        if done_bool or final_timestep:
            episode_step = 0
            if (not has_next_obs) and (not is_antmaze):
                continue

        obs_.append(dataset['observations'][i].astype(np.float32))
        if has_next_obs:
            next_obs_.append(dataset['next_observations'][i].astype(np.float32))
        else:
            next_obs_.append(dataset['observations'][i + 1].astype(np.float32))
        action_.append(dataset['actions'][i].astype(np.float32))
        reward_.append(dataset['rewards'][i].astype(np.float32))
        done_.append(done_bool)
        # NOTE: after a kept end-of-episode transition this leaves episode_step
        # at 1 (reset above, then incremented), matching upstream D4RL's counting.
        episode_step += 1

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
    }

def qlearning_dataset_checktraj(env, task=None, dataset=None, terminate_on_end=False, num_traj=0, **kwargs):
    """
    Q-learning style export with per-transition trajectory bookkeeping.

    Args:
        env: An OfflineEnv object (``_max_episode_steps`` is read when the
            dataset has no 'timeouts' field).
        task: Optional task name. If it contains "antmaze" (case-insensitive),
            terminal and final-timestep transitions are always kept.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): If False (default), transitions flagged as the
            final timestep are dropped (except for antmaze); the trajectory is
            still closed there.
        num_traj (int): If > 0, stop as soon as this many trajectories are closed.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        traj_length: Number of written transitions in each trajectory.
        episode_end_list: Cumulative number of written transitions at the end
            of each trajectory.
        data: Dict with keys 'start', 'observations', 'actions',
            'next_observations', 'next_actions', 'rewards', 'terminals',
            'timeouts', 'trajectory', 'step', plus 'qvel'/'qpos' when the
            dataset provides 'infos/qvel'/'infos/qpos'.

    A written end-of-episode transition belongs to the trajectory it ends (same
    'trajectory' id, counted in that trajectory's length). If the dataset has no
    'next_observations' (and the task is not antmaze), end-of-episode
    transitions are dropped because observations[i + 1] belongs to the next
    episode.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset
    has_next_act = 'next_actions' in dataset
    # qpos/qvel are optional; reading them unconditionally raised KeyError.
    has_qpos = 'infos/qpos' in dataset
    has_qvel = 'infos/qvel' in dataset
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset
    is_antmaze = bool(task) and 'antmaze' in str(task).lower()

    N = dataset['rewards'].shape[0]
    start_ = []
    obs_ = []
    next_obs_ = []
    action_ = []
    next_act_ = []
    reward_ = []
    done_ = []
    timeout_ = []
    traj_ = []
    step_ = []
    qvel_ = []
    qpos_ = []

    episode_step = 0   # transitions written in the open trajectory
    traj_step = 0      # transitions written overall
    check_traj = 0     # trajectories closed so far (= id of the open trajectory)
    traj_length = []
    episode_end_list = []

    for i in range(N - 1):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = bool(dataset['timeouts'][i])
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)
        is_end = done_bool or final_timestep
        drop_final = (not terminate_on_end) and final_timestep

        if drop_final and (not use_timeouts) and timeout_:
            # Without explicit timeouts, flag the last written transition instead.
            timeout_[-1] = True

        if is_end and (not is_antmaze) and (drop_final or not has_next_obs):
            # Drop this transition but still close the trajectory (exactly once).
            traj_length.append(episode_step)
            episode_end_list.append(traj_step)
            episode_step = 0
            check_traj += 1
            if num_traj > 0 and check_traj == num_traj:
                break
            continue

        start_.append(episode_step == 0)
        obs_.append(dataset['observations'][i].astype(np.float32))
        if has_next_obs:
            next_obs_.append(dataset['next_observations'][i].astype(np.float32))
        else:
            next_obs_.append(dataset['observations'][i + 1].astype(np.float32))
        action_.append(dataset['actions'][i].astype(np.float32))
        if has_next_act:
            next_act_.append(dataset['next_actions'][i].astype(np.float32))
        else:
            next_act_.append(dataset['actions'][i + 1].astype(np.float32))
        reward_.append(dataset['rewards'][i].astype(np.float32))
        done_.append(done_bool)
        traj_.append(check_traj)
        step_.append(i)
        if has_qvel:
            qvel_.append(dataset['infos/qvel'][i].astype(np.float32))
        if has_qpos:
            qpos_.append(dataset['infos/qpos'][i].astype(np.float32))
        timeout_.append(final_timestep if use_timeouts else done_bool)

        episode_step += 1
        traj_step += 1

        # Close the trajectory after its last transition has been written.
        if is_end:
            traj_length.append(episode_step)
            episode_end_list.append(traj_step)
            episode_step = 0
            check_traj += 1
            if num_traj > 0 and check_traj == num_traj:
                break

    # Report a trajectory left open at the end of the data.
    if episode_step > 0:
        traj_length.append(episode_step)
        episode_end_list.append(traj_step)

    data = {
        'start': np.array(start_),
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'next_actions': np.array(next_act_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'timeouts': np.array(timeout_),
        'trajectory': np.array(traj_),
        'step': np.array(step_),
    }
    if has_qvel:
        data['qvel'] = np.array(qvel_)
    if has_qpos:
        data['qpos'] = np.array(qpos_)

    return traj_length, episode_end_list, data


def qlearning_dataset_wo_timeout(env, dataset=None, terminate_on_end=False, num_traj=0, **kwargs):
    """
    Q-learning style export that ignores timeout flags: episodes are split
    only on ``terminals``.

    Args:
        env: An OfflineEnv object; only used when ``dataset`` is None.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Unused, since timeouts are ignored; kept for
            signature compatibility.
        num_traj (int): If > 0, stop as soon as this many trajectories are closed.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        traj_length: Number of written transitions in each trajectory.
        episode_end_list: Cumulative number of written transitions at the end
            of each trajectory.
        data: Dict with 'observations', 'actions', 'next_observations',
            'next_actions', 'rewards', 'terminals' and 'timeouts' (here a copy
            of the terminal flags).

    If the dataset has no 'next_observations', terminal transitions are dropped
    because observations[i + 1] belongs to the next episode; the trajectory is
    still closed there.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset
    has_next_act = 'next_actions' in dataset

    N = dataset['rewards'].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    next_act_ = []
    reward_ = []
    done_ = []
    timeout_ = []

    episode_step = 0   # transitions written in the open trajectory
    traj_step = 0      # transitions written overall
    check_traj = 0     # trajectories closed so far
    traj_length = []
    episode_end_list = []

    for i in range(N - 1):
        done_bool = bool(dataset['terminals'][i])

        if done_bool and not has_next_obs:
            # Drop the terminal transition but still close its trajectory.
            traj_length.append(episode_step)
            episode_end_list.append(traj_step)
            episode_step = 0
            check_traj += 1
            if num_traj > 0 and check_traj == num_traj:
                break
            continue

        obs_.append(dataset['observations'][i].astype(np.float32))
        if has_next_obs:
            next_obs_.append(dataset['next_observations'][i].astype(np.float32))
        else:
            next_obs_.append(dataset['observations'][i + 1].astype(np.float32))
        action_.append(dataset['actions'][i].astype(np.float32))
        if has_next_act:
            next_act_.append(dataset['next_actions'][i].astype(np.float32))
        else:
            next_act_.append(dataset['actions'][i + 1].astype(np.float32))
        reward_.append(dataset['rewards'][i].astype(np.float32))
        done_.append(done_bool)
        timeout_.append(done_bool)

        episode_step += 1
        traj_step += 1

        # Close the trajectory after its terminal transition has been written.
        if done_bool:
            traj_length.append(episode_step)
            episode_end_list.append(traj_step)
            episode_step = 0
            check_traj += 1
            if num_traj > 0 and check_traj == num_traj:
                break

    # Report a trajectory left open at the end of the data.
    if episode_step > 0:
        traj_length.append(episode_step)
        episode_end_list.append(traj_step)

    return traj_length, episode_end_list, {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'next_actions': np.array(next_act_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'timeouts': np.array(timeout_),
    }



def qlearning_dataset_traj(env, dataset=None, terminate_on_end=False, num_traj=0, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag, optionally limited to the first ``num_traj`` trajectories.

    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        num_traj (int): If > 0, stop as soon as this many trajectories have
            ended, so no sample of a later trajectory is included.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset

    N = dataset['rewards'].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []

    episode_step = 0
    check_traj = 0     # number of trajectories that have ended so far

    for i in range(N - 1):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)

        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            check_traj += 1
            if num_traj > 0 and check_traj == num_traj:
                break
            continue
        if done_bool or final_timestep:
            episode_step = 0
            check_traj += 1
            if not has_next_obs:
                # observations[i + 1] is the next episode's start: drop this step.
                if num_traj > 0 and check_traj == num_traj:
                    break
                continue

        obs_.append(dataset['observations'][i].astype(np.float32))
        if has_next_obs:
            next_obs_.append(dataset['next_observations'][i].astype(np.float32))
        else:
            next_obs_.append(dataset['observations'][i + 1].astype(np.float32))
        action_.append(dataset['actions'][i].astype(np.float32))
        reward_.append(dataset['rewards'][i].astype(np.float32))
        done_.append(done_bool)
        episode_step += 1

        if num_traj > 0 and num_traj == check_traj:
            break

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
    }
    
def qlearning_dataset_distance(env, dataset=None, terminate_on_end=False, num_traj=0, extra_traj=602, **kwargs):
    """
    Split the dataset's observations into one array per trajectory.

    The previous version built these per-trajectory arrays but had no return
    statement, so it always returned None. It now returns them.

    Args:
        env: An OfflineEnv object (``_max_episode_steps`` is read when the
            dataset has no 'timeouts' field).
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): If False (default), final-timestep transitions
            are dropped (the trajectory still ends there).
        num_traj (int): If > 0, stop once ``num_traj + extra_traj``
            trajectories have ended.
        extra_traj (int): Offset added to ``num_traj`` for the cut-off. The
            default 602 reproduces the previously hard-coded offset, whose
            meaning is not documented in the source.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        list of np.ndarray: for each trajectory, the stacked float32
        observations written for it (a trailing unterminated trajectory is
        included if non-empty). Without 'next_observations' in the dataset,
        terminal transitions are dropped, mirroring the other loaders.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset

    N = dataset['rewards'].shape[0]
    max_traj = num_traj + extra_traj

    traj_stack = []
    current = []       # observations of the open trajectory
    episode_step = 0
    check_traj = 0     # trajectories ended so far

    for i in range(N - 1):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = bool(dataset['timeouts'][i])
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)
        is_end = done_bool or final_timestep
        drop_final = (not terminate_on_end) and final_timestep

        # Keep the transition unless it is a skipped final step or a terminal
        # whose "next" observation would come from the following episode.
        if not (drop_final or (is_end and not has_next_obs)):
            current.append(dataset['observations'][i].astype(np.float32))
            episode_step += 1

        if is_end:
            # Close after writing, so a kept terminal belongs to its own trajectory.
            traj_stack.append(np.array(current))
            current = []
            episode_step = 0
            check_traj += 1
            if num_traj > 0 and check_traj == max_traj:
                break

    if current:
        traj_stack.append(np.array(current))

    return traj_stack
    
    


def nem_traj(obs_, action_, next_obs_, reward_, done_):
    """Stack the buffered observations and hand back fresh, empty buffers.

    Only ``obs_`` is kept (converted with ``np.array``); the other buffers are
    discarded.

    Returns:
        (traj_obs, obs_, action_, next_obs_, reward_, done_), where the last
        five entries are new empty lists.
    """
    stacked = np.array(obs_)
    empty_buffers = tuple([] for _ in range(5))
    return (stacked,) + empty_buffers
    
    

def qlearning_dataset_rewarr(env, dataset=None, terminate_on_end=False, reward_arrays=None, **kwargs):
    """
    Returns datasets formatted for use by standard Q-learning algorithms,
    with observations, actions, next_observations, rewards, and a terminal
    flag, plus the same row selection applied to extra reward arrays.

    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): Set done=True on the last timestep
            in a trajectory. Default is False, and will discard the
            last timestep in each trajectory.
        reward_arrays: Optional sequence of per-transition arrays, each indexed
            like dataset['rewards']. None (the default) means no extra arrays.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        (new_dataset, arr_dict):
            new_dataset: dict with observations, actions, next_observations,
                rewards and terminals arrays.
            arr_dict: {j: float32 array of reward_arrays[j] restricted to the
                kept rows}; empty when no reward arrays are given.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)
    # The default used to crash on len(None).
    if reward_arrays is None:
        reward_arrays = []

    has_next_obs = 'next_observations' in dataset
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset

    N = dataset['rewards'].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    reward_arrays_ = [[] for _ in reward_arrays]

    episode_step = 0
    for i in range(N - 1):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)

        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue
        if done_bool or final_timestep:
            episode_step = 0
            if not has_next_obs:
                # observations[i + 1] is the next episode's start: drop this step.
                continue

        obs_.append(dataset['observations'][i].astype(np.float32))
        if has_next_obs:
            next_obs_.append(dataset['next_observations'][i].astype(np.float32))
        else:
            next_obs_.append(dataset['observations'][i + 1].astype(np.float32))
        action_.append(dataset['actions'][i].astype(np.float32))
        reward_.append(dataset['rewards'][i].astype(np.float32))
        done_.append(done_bool)

        # Keep the extra reward arrays aligned with the selected rows.
        for j, extra in enumerate(reward_arrays):
            reward_arrays_[j].append(extra[i].astype(np.float32))

        episode_step += 1

    new_dataset = {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
    }
    new_reward_arrays = np.array(reward_arrays_)
    arr_dict = dict(enumerate(new_reward_arrays))

    return new_dataset, arr_dict


class SequenceDataset(torch.utils.data.Dataset):
    """Fixed-length windows over the trajectories of an offline dataset.

    The flat transition arrays are cut into trajectories. Then one window of up
    to ``max_len`` consecutive transitions is indexed for every start step of
    every trajectory. For each window:

    * ``inputs``  = normalized ``[observation, action]`` rows,
    * ``targets`` = ``[next_observation - observation, reward]`` rows,
    * ``masks``   = 1 for real steps, 0 for zero-padded steps past the end of
      the trajectory.

    Args:
        dataset: dict holding ``observations``, ``next_observations``,
            ``actions``, ``rewards`` and ``terminals`` arrays indexed by
            transition along the first axis, and optionally ``timeouts``.
        max_len: number of rows in every returned window.
        max_ep_len: episode length used to cut trajectories when the dataset
            has no ``timeouts`` key.
        device: torch device the returned tensors are placed on.
    """

    # Per-transition fields copied into each trajectory.
    _KEYS = ('observations', 'next_observations', 'actions', 'rewards', 'terminals')

    def __init__(self, dataset, max_len, max_ep_len=1000, device="cpu"):
        super().__init__()

        self.obs_dim = dataset["observations"].shape[-1]
        self.action_dim = dataset["actions"].shape[-1]
        self.max_len = max_len
        self.max_ep_len = max_ep_len
        self.device = torch.device(device)

        # Per-feature normalization statistics over every (obs, action) pair;
        # the epsilon avoids division by zero for constant features.
        all_inputs = np.concatenate([dataset["observations"], dataset["actions"]], axis=1)
        self.input_mean = all_inputs.mean(0)
        self.input_std = all_inputs.std(0) + 1e-6

        self.trajs = self._split_trajectories(dataset)

        # One (trajectory, start, end) triple per start step. `end` may run past
        # the trajectory; __getitem__ zero-pads in that case.
        indices = [
            (traj_ind, start, start + self.max_len)
            for traj_ind, traj in enumerate(self.trajs)
            for start in range(len(traj["rewards"]))
        ]
        # reshape keeps a (0, 3) shape when there are no transitions at all.
        self.indices = np.array(indices, dtype=np.int64).reshape(-1, 3)

    def _split_trajectories(self, dataset):
        """Cut the flat transition arrays into a list of per-episode dicts."""
        use_timeouts = 'timeouts' in dataset
        trajs = []
        data_ = collections.defaultdict(list)
        episode_step = 0
        for i in range(dataset["rewards"].shape[0]):
            for k in self._KEYS:
                data_[k].append(dataset[k][i])

            if use_timeouts:
                final_timestep = bool(dataset['timeouts'][i])
            else:
                final_timestep = episode_step == self.max_ep_len - 1

            if bool(dataset['terminals'][i]) or final_timestep:
                trajs.append({k: np.array(v) for k, v in data_.items()})
                data_ = collections.defaultdict(list)
                # The next transition is step 0 of a new episode, so do not
                # increment after resetting.
                episode_step = 0
            else:
                episode_step += 1

        # Keep a trailing partial trajectory (no terminal/timeout at the end)
        # instead of silently discarding its transitions.
        if data_:
            trajs.append({k: np.array(v) for k, v in data_.items()})
        return trajs

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(inputs, targets, masks)`` float32 tensors for window ``idx``.

        Shapes: ``inputs`` is (max_len, obs_dim + action_dim), ``targets`` is
        (max_len, obs_dim + 1) and ``masks`` is (max_len,).
        """
        traj_ind, start_ind, end_ind = self.indices[idx]
        traj = self.trajs[traj_ind]
        obss = traj['observations'][start_ind:end_ind]
        actions = traj['actions'][start_ind:end_ind]
        next_obss = traj['next_observations'][start_ind:end_ind]
        rewards = traj['rewards'][start_ind:end_ind].reshape(-1, 1)
        delta_obss = next_obss - obss

        # Zero-pad windows that run past the end of the trajectory.
        tlen = obss.shape[0]
        pad = self.max_len - tlen
        inputs = np.concatenate([obss, actions], axis=1)
        inputs = (inputs - self.input_mean) / self.input_std
        inputs = np.concatenate([inputs, np.zeros((pad, self.obs_dim + self.action_dim))], axis=0)
        targets = np.concatenate([delta_obss, rewards], axis=1)
        targets = np.concatenate([targets, np.zeros((pad, self.obs_dim + 1))], axis=0)
        masks = np.concatenate([np.ones(tlen), np.zeros(pad)], axis=0)

        inputs = torch.from_numpy(inputs).to(dtype=torch.float32, device=self.device)
        targets = torch.from_numpy(targets).to(dtype=torch.float32, device=self.device)
        masks = torch.from_numpy(masks).to(dtype=torch.float32, device=self.device)

        return inputs, targets, masks
