from typing import Optional
import numpy as np
import numba
from diffusion_policy.common.replay_buffer import ReplayBuffer


@numba.jit(nopython=True)
def create_indices(
        episode_ends: np.ndarray, sequence_length: int,
        episode_mask: np.ndarray,
        pad_before: int = 0, pad_after: int = 0,
        debug: bool = True) -> np.ndarray:
    """Enumerate every fixed-length window over the selected episodes.

    Episode ``i`` occupies buffer rows ``[episode_ends[i-1], episode_ends[i])``
    (the first episode starts at row 0). A window of ``sequence_length`` steps
    may begin up to ``pad_before`` steps before its episode starts and end up
    to ``pad_after`` steps after it ends; both pads are clamped to
    ``[0, sequence_length - 1]``.

    Args:
        episode_ends: exclusive end row of each episode in the buffer.
        sequence_length: number of steps per window.
        episode_mask: per-episode flags; episodes with a falsy flag are
            skipped. Must have the same shape as ``episode_ends``.
        pad_before: steps a window may extend before the episode start.
        pad_after: steps a window may extend past the episode end.
        debug: when True, assert internal consistency of every row.

    Returns:
        int64 array of shape ``(n_windows, 4)``. Each row is
        ``(buffer_start, buffer_end, sample_start, sample_end)``: buffer rows
        ``[buffer_start, buffer_end)`` fill window slots
        ``[sample_start, sample_end)``. The shape is ``(0, 4)`` when no window
        fits.

    Raises:
        AssertionError: if ``episode_mask`` and ``episode_ends`` differ in
            shape.
    """
    # This used to be a bare comparison expression (a no-op). Numba does not
    # bounds-check array indexing by default, so a mismatched mask must be
    # rejected explicitly.
    assert episode_mask.shape == episode_ends.shape
    pad_before = min(max(pad_before, 0), sequence_length - 1)
    pad_after = min(max(pad_after, 0), sequence_length - 1)
    n_episodes = len(episode_ends)

    # Pass 1: count the windows so the result can be preallocated. This keeps
    # the result 2-D even when nothing fits (e.g. every episode is shorter
    # than sequence_length with no padding).
    n_windows = 0
    for i in range(n_episodes):
        if not episode_mask[i]:
            continue
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i - 1]
        episode_length = episode_ends[i] - start_idx
        n_episode_windows = (episode_length - sequence_length
                             + pad_before + pad_after + 1)
        if n_episode_windows > 0:
            n_windows += n_episode_windows

    indices = np.empty((n_windows, 4), dtype=np.int64)

    # Pass 2: fill one row per window.
    row = 0
    for i in range(n_episodes):
        if not episode_mask[i]:
            # skip episode
            continue
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i - 1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # idx is the window start relative to the episode start; the range
        # includes max_start itself.
        for idx in range(min_start, max_start + 1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx + sequence_length, episode_length) + start_idx
            # number of window slots that fall before / after the episode
            start_offset = buffer_start_idx - (idx + start_idx)
            end_offset = (idx + sequence_length + start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            if debug:
                assert (start_offset >= 0)
                assert (end_offset >= 0)
                assert (sample_end_idx - sample_start_idx) == (buffer_end_idx - buffer_start_idx)
            indices[row, 0] = buffer_start_idx
            indices[row, 1] = buffer_end_idx
            indices[row, 2] = sample_start_idx
            indices[row, 3] = sample_end_idx
            row += 1
    return indices


def get_val_mask(n_episodes, val_ratio, seed=0):
    """Randomly flag a fraction of episodes for validation.

    Args:
        n_episodes: total number of episodes.
        val_ratio: fraction of episodes to hold out; ``<= 0`` disables
            validation entirely.
        seed: seed for ``np.random.default_rng`` so the split is reproducible.

    Returns:
        Boolean array of length ``n_episodes``; True marks a validation
        episode. When validation is enabled and ``n_episodes >= 2``, at least
        one episode is held out and at least one is left for training. With
        0 or 1 episodes nothing is held out.
    """
    val_mask = np.zeros(n_episodes, dtype=bool)
    if val_ratio <= 0:
        return val_mask

    # have at least 1 episode for validation, and at least 1 episode for train
    n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes - 1)
    if n_val <= 0:
        # Nothing can be held out while keeping a training episode. Without
        # this guard, n_episodes == 0 gives n_val == -1 and rng.choice raises.
        return val_mask
    rng = np.random.default_rng(seed=seed)
    val_idxs = rng.choice(n_episodes, size=n_val, replace=False)
    val_mask[val_idxs] = True
    return val_mask


def downsample_mask(mask, max_n, seed=0):
    """Cap the number of True entries in ``mask`` at ``max_n``.

    If ``max_n`` is None, or ``mask`` already has at most ``max_n`` True
    entries, the very same ``mask`` object is returned. Otherwise a new array
    (via ``np.zeros_like(mask)``) is returned in which exactly ``int(max_n)``
    of the originally selected positions remain set. Those positions are drawn
    without replacement from ``np.random.default_rng(seed)``.
    """
    if max_n is None or not (np.sum(mask) > max_n):
        return mask

    keep = int(max_n)
    # indices (along the first axis) of the currently selected entries
    candidates = np.nonzero(mask)[0]
    generator = np.random.default_rng(seed=seed)
    chosen = candidates[generator.choice(len(candidates), size=keep, replace=False)]

    reduced = np.zeros_like(mask)
    reduced[chosen] = True
    assert np.sum(reduced) == keep
    return reduced

class SequenceSampler:
    """Serve fixed-length, edge-padded windows over a replay buffer's episodes.

    Each sample covers ``sequence_length`` consecutive steps of one episode.
    Windows may start up to ``pad_before`` steps before the episode start and
    end up to ``pad_after`` steps after its end; the missing steps are filled
    by repeating the first / last loaded frame. For keys listed in
    ``key_first_k`` only the first ``k`` steps of each window are loaded and
    returned, time-aligned with the full window. The key ``"action"`` is
    exempt and always returns the full ``sequence_length`` window.
    """

    def __init__(self,
                 replay_buffer: ReplayBuffer,
                 sequence_length: int,
                 pad_before: int = 0,
                 pad_after: int = 0,
                 keys=None,
                 key_first_k=None,
                 episode_mask: Optional[np.ndarray] = None,
                 ):
        """
        Args:
            replay_buffer: source of ``episode_ends`` and per-key arrays.
            sequence_length: steps per window (>= 1).
            pad_before: steps a window may extend before an episode start.
            pad_after: steps a window may extend past an episode end.
            keys: keys to sample; defaults to all keys of ``replay_buffer``.
            key_first_k: optional ``{key: k}`` limiting those keys to the first
                ``k`` steps of each window. Copied, so later changes to the
                caller's dict have no effect. (Previously a shared mutable
                default.)
            episode_mask: per-episode flags selecting which episodes to sample;
                defaults to all episodes.
        """
        super().__init__()
        assert sequence_length >= 1
        if keys is None:
            keys = list(replay_buffer.keys())
        key_first_k = dict() if key_first_k is None else dict(key_first_k)

        episode_ends = replay_buffer.episode_ends[:]
        if episode_mask is None:
            episode_mask = np.ones(episode_ends.shape, dtype=bool)

        if np.any(episode_mask):
            base_indices = create_indices(
                episode_ends,
                sequence_length=sequence_length,
                pad_before=pad_before,
                pad_after=pad_after,
                episode_mask=episode_mask,
            )
        else:
            base_indices = np.zeros((0, 4), dtype=np.int64)

        self.sequence_length = sequence_length
        self.replay_buffer = replay_buffer
        self.keys = list(keys)
        self.key_first_k = key_first_k
        self.indices = base_indices

        # Per-key window indices, precomputed once. The transform is the
        # identity whenever the key's output length >= sequence_length.
        self.indices_dict = {
            key: self._first_k_indices(base_indices, self._output_length(key))
            for key in self.keys
        }

    def _output_length(self, key):
        """Return the number of steps ``sample_sequence`` produces for ``key``."""
        if key == "action":
            # actions always use the full window, even if listed in key_first_k
            return self.sequence_length
        return self.key_first_k.get(key, self.sequence_length)

    @staticmethod
    def _first_k_indices(indices, k):
        """Restrict each window to its first ``k`` output slots, keeping time alignment.

        Rows are ``(buffer_start, buffer_end, sample_start, sample_end)``:
        buffer rows ``[buffer_start, buffer_end)`` fill output slots
        ``[sample_start, sample_end)``. Slots before / after that range are
        padded with the first / last loaded frame. Returns a new array.
        """
        indices = np.asarray(indices)
        out = indices.copy()
        if len(out) == 0:
            return out
        buf_start = indices[:, 0]
        samp_start = indices[:, 2]
        samp_end = indices[:, 3]
        if k < 1:
            # degenerate request: load nothing
            out[:, 1] = buf_start
            out[:, 2] = 0
            out[:, 3] = 0
            return out
        # If k <= sample_start, every kept slot is leading padding. Keep the
        # padding source frame (buffer_start) at slot k-1 so there is data to
        # pad from; the output is identical to pure padding.
        new_start = np.minimum(samp_start, k - 1)
        new_end = np.minimum(samp_end, k)
        out[:, 1] = buf_start + (new_end - new_start)
        out[:, 2] = new_start
        out[:, 3] = new_end
        return out

    def __len__(self):
        return len(self.indices)

    def sample_sequence(self, idx):
        """Return ``{key: array}`` for window ``idx``.

        Each array has ``_output_length(key)`` steps along axis 0. If no
        padding is needed, the array is the slice taken directly from the
        replay buffer.
        """
        result = {}
        for key in self.keys:
            input_arr = self.replay_buffer[key]
            bs, be, ss, se = self.indices_dict[key][idx]
            out_len = self._output_length(key)
            sample = input_arr[bs:be]
            if (ss > 0) or (se < out_len):
                data = np.zeros((out_len,) + input_arr.shape[1:], dtype=input_arr.dtype)
                if ss > 0:
                    data[:ss] = sample[0]
                if se < out_len:
                    data[se:] = sample[-1]
                data[ss:se] = sample
            else:
                data = sample
            result[key] = data
        return result

