import numpy as np
import lmdb
import numpy as np

def save_numpy_to_lmdb(env: lmdb.Environment,
                       key: str,
                       arr: np.ndarray):
    """
    Write a NumPy array and its shape into LMDB.

    Two entries are stored inside one write transaction:
        key + '_data'  -> arr.tobytes()
        key + '_shape' -> the array shape as raw np.int32 values

    The dtype is not stored, so the reader has to supply it.
    """
    with env.begin(write=True) as txn:
        payload = arr.tobytes()
        dims = np.array(arr.shape, dtype=np.int32).tobytes()

        # Data entry first, shape entry second.
        for suffix, blob in (("data", payload), ("shape", dims)):
            txn.put(f"{key}_{suffix}".encode(), blob)

def store_data_dict_to_lmdb(env: lmdb.Environment,
                            data_dict: dict,
                            start_index: int,
                            sft=False, rl=False):
    """
    Write one sample from ``data_dict`` into LMDB in a single transaction.

    Always written (as arrays):
        text_feature, neg_text_feature, clip_fea, noise_shape, y
    Written only when ``sft`` is true:
        latent
    Written only when ``rl`` is true:
        prompt (a str, stored as the uint8 array of its UTF-8 bytes)

    Every field ``name`` produces two entries, ``{name}_{start_index}_data``
    and ``{name}_{start_index}_shape``, which is the same layout that
    ``save_numpy_to_lmdb`` writes and ``read_numpy_from_lmdb`` reads.

    All fields are looked up before the transaction is opened and then
    written together, so a missing field raises ``KeyError`` without leaving
    a partially written sample behind.

    Args:
        env: open LMDB environment.
        data_dict: mapping holding the fields listed above.
        start_index: index of this sample; each call writes exactly one sample.
        sft: also store ``data_dict["latent"]``.
        rl: also store ``data_dict["prompt"]``.

    Raises:
        KeyError: if a required field is absent from ``data_dict``.
    """
    idx = start_index  # each call writes exactly one sample

    field_names = ["text_feature", "neg_text_feature", "clip_fea",
                   "noise_shape", "y"]
    if sft:
        field_names.append("latent")

    # Resolve everything first: a failure here must not touch the database.
    entries = [(f"{name}_{idx}", data_dict[name]) for name in field_names]
    if rl:
        prompt = data_dict["prompt"]
        # Encode the text and view the bytes as a uint8 array.
        prompt_bytes = np.frombuffer(prompt.encode('utf-8'), dtype=np.uint8)
        entries.append((f"prompt_{idx}", prompt_bytes))

    # One write transaction for the whole sample: one commit, all-or-nothing.
    with env.begin(write=True) as txn:
        for key, arr in entries:
            txn.put(f"{key}_data".encode(), arr.tobytes())
            txn.put(f"{key}_shape".encode(),
                    np.array(arr.shape, dtype=np.int32).tobytes())

def read_numpy_from_lmdb(env: lmdb.Environment,
                         key: str,
                         dtype: np.dtype) -> np.ndarray:
    """
    Read the array stored under ``key`` from LMDB and restore its shape.

    The bytes come from ``key + '_data'`` and the shape (np.int32 values)
    from ``key + '_shape'``. The dtype is not read from the database, so the
    caller must pass the dtype the array was written with; a wrong dtype
    yields garbage values or a reshape error.

    Args:
        env: open LMDB environment.
        key: base key, without the ``_data`` / ``_shape`` suffix.
        dtype: element type used to interpret the stored bytes.

    Returns:
        A writable array that owns its memory.

    Raises:
        KeyError: if the data entry or the shape entry is missing.
    """
    data_key = f"{key}_data".encode()
    shape_key = f"{key}_shape".encode()

    with env.begin() as txn:
        buf = txn.get(data_key)
        shape_buf = txn.get(shape_key)

        # txn.get returns None for an absent key; fail with a clear message
        # instead of an obscure error from np.frombuffer(None).
        if buf is None or shape_buf is None:
            raise KeyError(
                f"no array stored under key {key!r} "
                f"(expected entries {key}_data and {key}_shape)"
            )

        # A plain tuple of ints is the most portable reshape argument.
        shape = tuple(np.frombuffer(shape_buf, dtype=np.int32).tolist())

        arr = np.frombuffer(buf, dtype=dtype).reshape(shape)
        # frombuffer gives a read-only view over the bytes; copy() returns a
        # writable array that no longer depends on the fetched buffer.
        return arr.copy()

def get_array_shape_from_lmdb(env):
    """
    Return the integer stored under the ``local_dataset_len`` key.

    Despite the function name, the result is a single int: the stored bytes
    are decoded as text and parsed with ``int``.
    """
    length_key = "local_dataset_len".encode()
    with env.begin() as txn:
        raw_value = txn.get(length_key)
        return int(raw_value.decode())


def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
    """
    Store every row of several arrays in one LMDB write transaction.

    Row ``i`` of the array named ``name`` is written under the key
    ``{name}_{start_index + i}_data``. A ``str`` row is stored as its encoded
    text; any other row is stored as ``row.tobytes()``. No shape or dtype
    information is written. One progress line is printed per row.
    """
    with env.begin(write=True) as txn:
        for array_name, array in arrays_dict.items():
            for offset, row in enumerate(array):
                # Text rows are encoded; everything else is dumped as raw bytes.
                payload = row.encode() if isinstance(row, str) else row.tobytes()

                row_index = start_index + offset
                print(f"-------saved {array_name}_{row_index}_data")
                txn.put(f'{array_name}_{row_index}_data'.encode(), payload)


def process_data_dict(data_dict, seen_prompts):
    """
    Collect the not-yet-seen entries of ``data_dict`` into stacked arrays.

    Args:
        data_dict: mapping of prompt -> video. Each video must provide the
            ``detach`` / ``half`` / ``cpu`` / ``numpy`` methods
            (presumably a ``torch.Tensor`` -- confirm against the caller).
        seen_prompts: set of prompts already processed. It is updated in
            place: every prompt handled by this call is added to it, and
            prompts already present are skipped.

    Returns:
        A dict with two entries:
            ``latents``: the float16 videos concatenated along axis 0;
            ``prompts``: array of the prompts that were kept.
        When nothing new is found, both entries are empty arrays.

    NOTE(review): rows of ``latents`` line up one-to-one with ``prompts``
    only if every video has a leading dimension of 1 -- confirm.
    """
    all_videos = []
    all_prompts = []
    for prompt, video in data_dict.items():
        if prompt in seen_prompts:
            continue
        seen_prompts.add(prompt)

        # detach() drops gradient tracking and cpu() moves the data to host
        # memory; without them numpy() raises for tensors that require grad
        # or live on an accelerator. Both are no-ops for plain CPU tensors.
        all_videos.append(video.detach().half().cpu().numpy())
        all_prompts.append(prompt)

    if not all_videos:
        print("no video found!")
        return {"latents": np.array([]), "prompts": np.array([])}

    return {
        "latents": np.concatenate(all_videos, axis=0),
        "prompts": np.array(all_prompts),
    }


def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape=None):
    """
    Retrieve a specific row from a specific array in the LMDB.

    The row is looked up under the key ``{array_name}_{row_index}_data``.

    Args:
        lmdb_env: open LMDB environment.
        array_name: name the array was stored under.
        dtype: ``str`` to get the row back as decoded text, otherwise the
            NumPy dtype used to interpret the stored bytes.
        row_index: index of the row to fetch.
        shape: optional target shape for numeric rows; ``None`` or an empty
            shape leaves the row flat. Ignored for ``str`` rows.

    Returns:
        A ``str`` when ``dtype`` is ``str``; otherwise a NumPy array backed
        by the fetched bytes (as produced by ``np.frombuffer``).

    Raises:
        KeyError: if no row is stored under the computed key.
    """
    data_key = f'{array_name}_{row_index}_data'.encode()

    with lmdb_env.begin() as txn:
        row_bytes = txn.get(data_key)

    # txn.get returns None for an absent key; report which key was missing.
    if row_bytes is None:
        raise KeyError(f"no row stored under key {data_key.decode()!r}")

    # Text rows have no shape: return early so reshape is never applied
    # to a Python string.
    if dtype == str:
        return row_bytes.decode()

    array = np.frombuffer(row_bytes, dtype=dtype)
    if shape is not None and len(shape) > 0:
        array = array.reshape(shape)
    return array
