import zarr
import numpy as np
import cv2
import os
from tqdm import tqdm
import shutil
from diffusion_policy.env.pusht.pusht_env import PushTEnv


def add_rewards_to_zarr(
    input_zarr_path: str,
    output_zarr_path: str,
    env_kwargs: dict = None
) -> None:
    """
    Replay every (state, action) pair of a zarr dataset in PushTEnv and write
    the data plus the per-step reward to a new zarr.

    For step t of each episode the environment is reset to ``data/state[t]``,
    stepped once with ``data/action[t]``, and the reward returned by
    ``env.step`` is stored as ``data/reward[t]`` (float32). All other arrays
    of the input ``data`` group are copied unchanged, keeping their chunks,
    dtype, compressor and filters. ``meta/episode_ends`` is rewritten from the
    copied episode lengths.

    Args:
        input_zarr_path: path to an existing .zarr with ``/data`` (must contain
            ``state`` and ``action`` arrays) and ``/meta/episode_ends``.
        output_zarr_path: path where the new .zarr is created. An existing
            directory at this path is deleted first. Must not resolve to the
            same location as ``input_zarr_path``.
        env_kwargs: kwargs forwarded to ``PushTEnv(...)``, e.g.
            ``{'render_action': False}``. ``None`` means no kwargs.

    Raises:
        ValueError: if input and output paths resolve to the same location.
        KeyError: if the input ``data`` group lacks ``state`` or ``action``.
    """
    env_kwargs = env_kwargs or {}

    # Check this before touching the filesystem: wiping the output path would
    # otherwise delete the very store we are about to read from.
    if os.path.realpath(input_zarr_path) == os.path.realpath(output_zarr_path):
        raise ValueError(
            f"output_zarr_path must differ from input_zarr_path: {input_zarr_path!r}"
        )

    # Open input (read-only)
    in_root = zarr.open(input_zarr_path, mode='r')
    data_in = in_root['data']
    ends_in = in_root['meta']['episode_ends'][:]

    in_keys = list(data_in.array_keys())
    missing = [k for k in ('state', 'action') if k not in in_keys]
    if missing:
        # Fail before the output is wiped, with a message naming what is absent.
        raise KeyError(f"input 'data' group is missing required arrays: {missing}")
    # Any pre-existing 'reward' array is replaced by the recomputed one. Copying
    # it as well would append two reward sequences per episode into one array.
    copy_keys = [k for k in in_keys if k != 'reward']

    # Prepare output Zarr
    if os.path.isdir(output_zarr_path):
        shutil.rmtree(output_zarr_path)
    out_root = zarr.open(output_zarr_path, mode='w')
    data_out = out_root.create_group('data')
    meta_out = out_root.create_group('meta')

    # Empty (length 0 along axis 0) arrays mirroring the input encoding. They
    # grow via Array.append, whose append axis defaults to 0; the append axis
    # is not a creation-time option.
    zarr_store = {}
    for key in copy_keys:
        arr = data_in[key]
        zarr_store[key] = data_out.create_dataset(
            name=key,
            shape=(0,) + arr.shape[1:],
            chunks=arr.chunks,
            dtype=arr.dtype,
            compressor=arr.compressor,
            filters=arr.filters,
            overwrite=True,
        )
    zarr_store['reward'] = data_out.create_dataset(
        name='reward',
        shape=(0,),
        chunks=(1024,),
        dtype='f4',
        overwrite=True,
    )

    # Episode i spans [starts[i], ends_in[i]).
    starts = np.concatenate(([0], ends_in[:-1]))
    total = 0
    new_ends = []

    # Every step below re-initialises the env through reset() with an explicit
    # reset_to_state, so a single instance serves all episodes. It is closed
    # in `finally` so its resources are released even if replay fails.
    env = PushTEnv(**env_kwargs)
    try:
        for s, e in tqdm(zip(starts, ends_in), total=len(ends_in), desc='episodes'):
            s, e = int(s), int(e)
            # slice the episode from input (already in the source dtype)
            episode = {k: data_in[k][s:e] for k in copy_keys}
            states = episode['state']
            actions = episode['action']
            T = e - s

            # r[t] = reward after stepping from states[t] with actions[t]
            rewards = np.zeros((T,), dtype=np.float32)
            for t in range(T):
                env.reset_to_state = states[t]
                env.reset()
                _, r, _, _ = env.step(actions[t])
                rewards[t] = np.float32(r)

            # append all arrays
            for k in copy_keys:
                zarr_store[k].append(episode[k])
            zarr_store['reward'].append(rewards)

            total += T
            new_ends.append(total)
    finally:
        env.close()

    # write new episode_ends
    meta_out.create_dataset('episode_ends', data=np.array(new_ends, dtype=np.int64))

    print(f"✅ Written with rewards to {output_zarr_path} "
          f"({len(new_ends)} episodes, {total} steps).")
   