import numpy as np
from utils.datasets import Dataset


def _as_array(items, item_shape, dtype):
    """Stack `items` into one array; an empty list yields a (0, *item_shape) array."""
    if items:
        return np.asarray(items, dtype=dtype)
    return np.empty((0,) + tuple(item_shape), dtype=dtype)


def build_two_step_dataset(
    ds,
    done_key='terminals',  # flag that marks the LAST step
    timeout_key=None,  # optional truncation flag
    shift_done_left=False,  # True if the flag is at t+1
):
    """
    Convert a 1-step Dataset/ReplayBuffer into a stride-2 dataset that stores
    (o_t, [a_t, a_{t+1}]) pairs and the matching bookkeeping arrays.

    Args:
        ds: Mapping-like buffer with 'observations', 'actions' and `done_key`
            arrays; 'next_observations' and 'rewards' are used when present.
        done_key: Key of the flag that marks the last step of an episode.
        timeout_key: Optional key of a truncation flag. When given and present
            in `ds`, it is OR-ed into the episode-end flag.
        shift_done_left: If True, the flag was recorded one step late; it is
            rolled left by one and the final step is forced to end an episode.

    Returns:
        The result of `Dataset.create(...)` with these fields:
          * observations: observation at the first step of each pair.
          * actions: both actions concatenated along the last axis, i.e.
            shape (2 * action_dim,) per pair (NOT (2, action_dim)).
          * terminals: episode-end flag of the second step of the pair (this
            includes timeouts when `timeout_key` is used), cast to the dtype
            of `ds[done_key]`.
          * masks: 1.0 - terminals.
          * next_observations (if present in `ds`): taken from the second
            step of the pair.
          * rewards (if present in `ds`): undiscounted float32 sum of the two
            step rewards.

    Notes:
        * Pairs never straddle an episode boundary and never overlap.
        * Odd-length episodes drop their FIRST step so that the last pair ends
          exactly on the episode's final step; single-step episodes produce no
          pair at all.
        * Steps after the last episode-end flag (an unfinished trailing
          episode) are ignored.
    """
    # Pull raw arrays
    obs_arr = ds['observations']  # (T, *obs_dim)
    act_arr = ds['actions']  # (T, act_dim)
    done_raw = ds[done_key]
    term_dtype = done_raw.dtype
    done = done_raw.astype(bool)  # (T,) -- astype returns a copy, safe to mutate

    if timeout_key and timeout_key in ds:
        done = np.logical_or(done, ds[timeout_key].astype(bool))

    if shift_done_left:  # env recorded "done" at t+1
        done = np.roll(done, -1)
        if done.size:  # guard: an empty buffer has no last element to set
            done[-1] = True

    has_next_obs = 'next_observations' in ds
    next_obs_arr = ds['next_observations'] if has_next_obs else None
    has_rewards = 'rewards' in ds
    rew_arr = ds['rewards'].astype(np.float32) if has_rewards else None

    # Walk through episodes and create stride-2 pairs
    obs_new, act_new, next_obs_new, term_new, rew_new = [], [], [], [], []
    ep_start = 0

    for t in np.flatnonzero(done):  # t is the index of an episode's last step
        t = int(t)
        ep_len = t + 1 - ep_start
        # Even length: start at the first step. Odd length: skip the first
        # step so the final pair still ends on the terminal step. A
        # single-step episode yields an empty range and hence no pairs.
        first = ep_start + ep_len % 2

        for idx in range(first, t, 2):
            obs_new.append(obs_arr[idx])

            # Concatenate the two actions along the last axis -> (2 * action_dim,).
            # atleast_1d lets scalar (0-d) actions form a length-2 vector.
            act_new.append(
                np.concatenate(
                    [np.atleast_1d(act_arr[idx]), np.atleast_1d(act_arr[idx + 1])],
                    axis=-1,
                )
            )

            # Terminal flag comes from the second action in the pair
            term_new.append(done[idx + 1])

            if has_next_obs:
                # Next observation is the one reached after the second action
                next_obs_new.append(next_obs_arr[idx + 1])
            if has_rewards:
                # Plain (undiscounted) sum of both step rewards
                rew_new.append(rew_arr[idx] + rew_arr[idx + 1])

        ep_start = t + 1  # next episode starts here

    # Per-pair action shape, used only when no pair was produced; it must
    # match what the concatenation above yields for a non-empty result.
    act_tail = tuple(act_arr.shape[1:])
    pair_shape = (act_tail[:-1] + (2 * act_tail[-1],)) if act_tail else (2,)

    # Convert lists to arrays
    obs_new = _as_array(obs_new, obs_arr.shape[1:], obs_arr.dtype)
    act_new = _as_array(act_new, pair_shape, act_arr.dtype)
    term_new = _as_array(term_new, (), term_dtype)

    data = dict(observations=obs_new, actions=act_new, terminals=term_new, masks=1.0 - term_new)

    if has_next_obs:
        data['next_observations'] = _as_array(next_obs_new, next_obs_arr.shape[1:], next_obs_arr.dtype)

    if has_rewards:
        data['rewards'] = _as_array(rew_new, (), rew_arr.dtype)

    # Print debug info
    print(f'Original dataset size: {len(obs_arr)}')
    print(f'New dataset size: {len(obs_new)}')
    if len(act_new) > 0:
        print(f'Original action shape: {act_arr[0].shape}')
        print(f'New action shape: {act_new[0].shape}')
    for key, value in data.items():
        print(f'{key}: {value.shape}')

    # Create and return the new dataset
    return Dataset.create(**data)


def discounted_sum(rewards: np.ndarray, terminals: np.ndarray, discount: float) -> np.ndarray:
    """
    Compute the discounted return-to-go for every time-step.

    For each step t, R_t = sum_{k=t}^{T-1} discount^{k-t} * r_k, where T-1 is
    the last step of the episode containing t (the terminal step's own reward
    is included).

    Args:
        rewards: Per-step rewards, indexable by time-step.
        terminals: Same length as `rewards`; truthy at the last step of each
            episode, falsy elsewhere.
        discount: Discount factor applied per step.

    Returns:
        Float array shaped like `rewards` holding R_t for every step.

    If you pass an entire buffer, trajectories are assumed to be concatenated;
    returns never leak across an episode boundary. Steps after the last
    terminal flag are treated as one (unfinished) trajectory.
    """
    rtg = np.zeros_like(rewards, dtype=float)
    running_sum = 0.0
    # Walk backwards so each step can reuse the return of its successor.
    for t in reversed(range(len(rewards))):
        if terminals[t]:
            # Step t ends an episode: whatever was accumulated so far belongs
            # to the NEXT episode and must not flow into this one. The reset
            # happens before adding r_t so the terminal reward is kept for
            # the earlier steps of this episode.
            running_sum = 0.0
        running_sum = rewards[t] + discount * running_sum
        rtg[t] = running_sum
    return rtg
