from typing import Optional
import numpy as np
import numba
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.common.replay_buffer_d4rl import ReplayBufferD4RL


@numba.jit(nopython=True)
def create_indices(
    episode_ends:np.ndarray, sequence_length:int, 
    episode_mask: np.ndarray,
    pad_before: int=0, pad_after: int=0,
    debug:bool=True) -> np.ndarray:
    episode_mask.shape == episode_ends.shape        
    pad_before = min(max(pad_before, 0), sequence_length-1)
    pad_after = min(max(pad_after, 0), sequence_length-1)

    indices = list()
    for i in range(len(episode_ends)):
        if not episode_mask[i]:
            # skip episode
            continue
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx
        
        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            if debug:
                assert(start_offset >= 0)
                assert(end_offset >= 0)
                assert (sample_end_idx - sample_start_idx) == (buffer_end_idx - buffer_start_idx)
            indices.append([
                buffer_start_idx, buffer_end_idx, 
                sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices






@numba.jit(nopython=True)
def create_indices_new(
    path_lengths:np.ndarray, sequence_length:int, 
    episode_mask: np.ndarray,
    pad_before: int=0, pad_after: int=0,
    debug:bool=True) -> np.ndarray:


    pad_before = min(max(pad_before, 0), sequence_length-1)
    pad_after = min(max(pad_after, 0), sequence_length-1)

    indices = []
    for i, path_length in enumerate(path_lengths):
        # max_start = min(path_length - 1, self.max_path_length - sequence_length)
        # if not self.use_padding:
        max_start = min(path_length - 1, path_length - sequence_length)
        for start in range(max_start):
            end = start + sequence_length
            indices.append((i, start, end))
    indices = np.array(indices)
    return indices





def get_val_mask(n_episodes, val_ratio, seed=0):
    val_mask = np.zeros(n_episodes, dtype=bool)
    if val_ratio <= 0:
        return val_mask

    # have at least 1 episode for validation, and at least 1 episode for train
    n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes-1)
    rng = np.random.default_rng(seed=seed)
    val_idxs = rng.choice(n_episodes, size=n_val, replace=False)
    val_mask[val_idxs] = True
    return val_mask


def downsample_mask(mask, max_n, seed=0):
    # subsample training data
    train_mask = mask
    if (max_n is not None) and (np.sum(train_mask) > max_n):
        n_train = int(max_n)
        curr_train_idxs = np.nonzero(train_mask)[0]
        rng = np.random.default_rng(seed=seed)
        train_idxs_idx = rng.choice(len(curr_train_idxs), size=n_train, replace=False)
        train_idxs = curr_train_idxs[train_idxs_idx]
        train_mask = np.zeros_like(train_mask)
        train_mask[train_idxs] = True
        assert np.sum(train_mask) == n_train
    return train_mask

class SequenceSampler:
    def __init__(self, 
        replay_buffer: ReplayBuffer, 
        sequence_length:int,
        pad_before:int=0,
        pad_after:int=0,
        keys=None,
        key_first_k=dict(),
        episode_mask: Optional[np.ndarray]=None,
        ):
        """
        key_first_k: dict str: int
            Only take first k data from these keys (to improve perf)
        """

        super().__init__()
        assert(sequence_length >= 1)
        # ['action', 'agentview_image', 'rewards', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_eye_in_hand_image', 'robot0_gripper_qpos']
        if keys is None:
            keys = list(replay_buffer.keys())
        
        episode_ends = replay_buffer.episode_ends[:]
        if episode_mask is None:
            episode_mask = np.ones(episode_ends.shape, dtype=bool)

        if np.any(episode_mask):
            indices = create_indices(episode_ends, 
                sequence_length=sequence_length, 
                pad_before=pad_before, 
                pad_after=pad_after,
                episode_mask=episode_mask
                )
        else:
            indices = np.zeros((0,4), dtype=np.int64)

        # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx)
        self.indices = indices 
        self.keys = list(keys) # prevent OmegaConf list performance problem
        self.sequence_length = sequence_length
        self.replay_buffer = replay_buffer
        self.key_first_k = key_first_k
    
    def __len__(self):
        return len(self.indices)
        
    def sample_sequence(self, idx):
        buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx \
            = self.indices[idx]
        result = dict()
        for key in self.keys:
            input_arr = self.replay_buffer[key]
            # performance optimization, avoid small allocation if possible
            if key not in self.key_first_k:
                sample = input_arr[buffer_start_idx:buffer_end_idx]
            else:
                # performance optimization, only load used obs steps
                n_data = buffer_end_idx - buffer_start_idx
                k_data = min(self.key_first_k[key], n_data)
                # fill value with Nan to catch bugs
                # the non-loaded region should never be used
                sample = np.full((n_data,) + input_arr.shape[1:], 
                    fill_value=np.nan, dtype=input_arr.dtype)
                try:
                    sample[:k_data] = input_arr[buffer_start_idx:buffer_start_idx+k_data]
                except Exception as e:
                    import pdb; pdb.set_trace()
            data = sample
            # if (buffer_end_idx - buffer_start_idx != self.sequence_length):
            #     print(buffer_start_idx - buffer_end_idx)
            if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length):
                data = np.zeros(
                    shape=(self.sequence_length,) + input_arr.shape[1:],
                    dtype=input_arr.dtype)
                if sample_start_idx > 0:
                    data[:sample_start_idx] = sample[0]
                if sample_end_idx < self.sequence_length:
                    data[sample_end_idx:] = sample[-1]
                data[sample_start_idx:sample_end_idx] = sample
            result[key] = data
        return result






class NewSequenceSampler:
    def __init__(self, 
        replay_buffer: ReplayBufferD4RL, 
        sequence_length:int,
        pad_before:int=0,
        pad_after:int=0,
        keys=None,
        key_first_k=dict(),
        episode_mask: Optional[np.ndarray]=None,
        ):
        """
        key_first_k: dict str: int
            Only take first k data from these keys (to improve perf)
        """

        super().__init__()
        self.pad_before = pad_before
        self.pad_after = pad_after
        assert(sequence_length >= 1)
        # ['action', 'agentview_image', 'rewards', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_eye_in_hand_image', 'robot0_gripper_qpos']
        if keys is None:
            keys = list(replay_buffer._dict.keys())
        
        path_lengths = replay_buffer['path_lengths']


        # if np.any(episode_mask):
        indices = create_indices_new(path_lengths, 
            sequence_length=sequence_length, 
            pad_before=pad_before, 
            pad_after=pad_after,
            episode_mask=episode_mask
            )
        # else:
        #     indices = np.zeros((0,4), dtype=np.int64)

        # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx)
        self.indices = indices 
        self.keys = list(keys) # prevent OmegaConf list performance problem
        self.sequence_length = sequence_length
        self.replay_buffer = replay_buffer
        self.key_first_k = key_first_k
    
    def __len__(self):
        return len(self.indices)
    
    def sample_sequence(self, idx):
        # trajectory_idx, buffer_start_idx, buffer_end_idx = self.indices[idx]
        # result = dict()
        # ### ['path_lengths', 'actions', 'infos/action_log_probs', 'infos/qpos', 'infos/qvel', 'next_observations', 'observations', 'rewards', 'terminals', 'timeouts']
        # for key in self.keys:
        #     input_arr = self.replay_buffer[key]
        #     # performance optimization, avoid small allocation if possible
        #     if "path_lengths" in key:
        #         continue
        #     if key not in self.key_first_k:
        #         sample = input_arr[trajectory_idx, buffer_start_idx:buffer_end_idx,...]
        #     else:
        #         # performance optimization, only load used obs steps
        #         n_data = buffer_end_idx - buffer_start_idx
        #         k_data = min(self.key_first_k[key], n_data)
        #         # fill value with Nan to catch bugs
        #         # the non-loaded region should never be used
        #         sample = np.full((n_data,) + input_arr.shape[1:], 
        #             fill_value=np.nan, dtype=input_arr.dtype)
        #         try:
        #             sample[:k_data] = input_arr[trajectory_idx, buffer_start_idx:buffer_start_idx+k_data,...]
        #         except Exception as e:
        #             import pdb; pdb.set_trace()
        #     data = sample
        #     # if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length):
        #     #     data = np.zeros(
        #     #         shape=(self.sequence_length,) + input_arr.shape[1:],
        #     #         dtype=input_arr.dtype)
        #     #     if sample_start_idx > 0:
        #     #         data[:sample_start_idx] = sample[0]
        #     #     if sample_end_idx < self.sequence_length:
        #     #         data[sample_end_idx:] = sample[-1]
        #     #     data[sample_start_idx:sample_end_idx] = sample
        #     result[key] = data

        result = dict()
        path_ind, start, end = self.indices[idx]
        # for key in self.keys:
            # input_arr = self.replay_buffer[key]
            # print(input_arr.shape)
            # print(self.replay_buffer.observations.shape)
            # input_arr = input_arr[path_ind, start:end]
            # result[key] = input_arr
        # observations = self.replay_buffer.observations[path_ind, start:end]
        # actions = self.replay_buffer.actions[path_ind, start:end]
        # rewards = self.replay_buffer.rewards[path_ind, start:end]

        observations = self.replay_buffer.normed_observations[path_ind, start:end]
        actions = self.replay_buffer.normed_actions[path_ind, start:end]
        rewards = self.replay_buffer.rewards[path_ind, start:end]

        result['observations'] = observations
        result['actions'] = actions
        result['rewards'] = rewards

        return result