from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb, read_numpy_from_lmdb
from torch.utils.data import Dataset
import numpy as np
import torch
import lmdb
import json
from pathlib import Path
from PIL import Image
import os
import torchvision.transforms.functional as TF

class ShardingLMDB_T2V_Dataset(Dataset):
    """Dataset of text-conditioning rows spread over several LMDB shards.

    Every entry of ``data_path`` (visited in sorted name order) is opened as
    a read-only LMDB environment. A flat ``index`` list maps a global sample
    position to a ``(shard_id, local_idx)`` pair.

    Args:
        data_path: Directory whose entries are the LMDB shards.
        max_pair: Kept on the instance as ``self.max_pair``; this class does
            not use it for anything else.
    """

    def __init__(self, data_path: str, max_pair: int = int(1e8)):
        self.max_pair = max_pair
        self.env_len = {}
        self.index = []

        # Open all shards first, in deterministic (sorted) order.
        self.envs = [
            lmdb.open(
                os.path.join(data_path, shard_name),
                readonly=True,
                lock=False,
                readahead=False,
                meminit=False,
            )
            for shard_name in sorted(os.listdir(data_path))
        ]

        # Then record each shard's row count and extend the global index.
        # The value returned by get_array_shape_from_lmdb is fed straight to
        # range(), i.e. it is treated as the number of rows in the shard.
        for shard_id, shard_env in enumerate(self.envs):
            row_count = get_array_shape_from_lmdb(shard_env)
            self.env_len[shard_id] = row_count
            self.index.extend(
                (shard_id, local_i) for local_i in range(row_count)
            )

    def __len__(self):
        """Return the total number of rows across all shards."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict with ``text_feature`` and ``neg_text_feature`` as torch
            tensors (read as float32), ``noise_shape`` exactly as
            ``read_numpy_from_lmdb`` returned it (read as int64), and
            ``prompt`` as a ``str`` decoded from UTF-8 bytes.
        """
        shard_id, local_idx = self.index[idx]
        shard_env = self.envs[shard_id]

        # Read the per-row keys back; the key naming presumably mirrors what
        # store_data_dict_to_lmdb wrote (that writer is not visible here).
        pos_feature = read_numpy_from_lmdb(shard_env, f"text_feature_{local_idx}", np.float32)
        neg_feature = read_numpy_from_lmdb(shard_env, f"neg_text_feature_{local_idx}", np.float32)
        noise_shape = read_numpy_from_lmdb(shard_env, f"noise_shape_{local_idx}", np.int64)

        # The prompt is stored as raw uint8 bytes of its UTF-8 encoding.
        prompt_bytes = read_numpy_from_lmdb(shard_env, f"prompt_{local_idx}", np.uint8)
        prompt = prompt_bytes.tobytes().decode('utf-8')

        # Only the feature arrays become tensors; noise_shape is passed through.
        return {
            'text_feature': torch.from_numpy(pos_feature),
            'neg_text_feature': torch.from_numpy(neg_feature),
            'noise_shape': noise_shape,
            'prompt': prompt,
        }


class ShardingLMDB_SFT_Dataset(Dataset):
    """Dataset of conditioning features plus latents stored in LMDB shards.

    Each entry of ``data_path`` is opened (in sorted name order) as one
    read-only LMDB environment. ``index`` is a flat list of
    ``(shard_id, local_idx)`` pairs covering every row of every shard.

    Args:
        data_path: Directory whose entries are the LMDB shards.
        max_pair: Saved as ``self.max_pair``; not otherwise used by this class.
    """

    def __init__(self, data_path: str, max_pair: int = int(1e8)):
        self.max_pair = max_pair
        self.latents_shape = {}
        self.index = []

        shard_paths = (
            os.path.join(data_path, entry)
            for entry in sorted(os.listdir(data_path))
        )
        self.envs = [
            lmdb.open(shard_path,
                      readonly=True,
                      lock=False,
                      readahead=False,
                      meminit=False)
            for shard_path in shard_paths
        ]

        # get_array_shape_from_lmdb's result is passed to range(), so it is
        # used here as the shard's row count.
        for shard_id, shard_env in enumerate(self.envs):
            n_rows = get_array_shape_from_lmdb(shard_env)
            self.latents_shape[shard_id] = n_rows
            self.index += [(shard_id, local_i) for local_i in range(n_rows)]

    def __len__(self):
        """Return the total number of rows across all shards."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict holding torch tensors for ``text_feature``, ``y``,
            ``clip_fea``, ``neg_text_feature`` and ``latent``; ``noise_shape``
            exactly as ``read_numpy_from_lmdb`` returned it; and ``prompt``
            as a UTF-8 decoded ``str``.
        """
        shard_id, local_idx = self.index[idx]
        shard_env = self.envs[shard_id]

        # Per-row keys; naming presumably matches store_data_dict_to_lmdb
        # (the writer is not visible from this file).
        raw_text = read_numpy_from_lmdb(shard_env, f"text_feature_{local_idx}", np.float32)
        raw_neg_text = read_numpy_from_lmdb(shard_env, f"neg_text_feature_{local_idx}", np.float32)
        raw_clip = read_numpy_from_lmdb(shard_env, f"clip_fea_{local_idx}", np.float16)
        noise_shape = read_numpy_from_lmdb(shard_env, f"noise_shape_{local_idx}", np.int64)
        raw_y = read_numpy_from_lmdb(shard_env, f"y_{local_idx}", np.float32)
        raw_latent = read_numpy_from_lmdb(shard_env, f"latent_{local_idx}", np.float32)

        # The prompt is stored as the uint8 bytes of its UTF-8 encoding.
        prompt = (
            read_numpy_from_lmdb(shard_env, f"prompt_{local_idx}", np.uint8)
            .tobytes()
            .decode('utf-8')
        )

        # Wrap the feature arrays as tensors; noise_shape is left untouched.
        text_feature, neg_text_feature, clip_fea, y, latent = map(
            torch.from_numpy,
            (raw_text, raw_neg_text, raw_clip, raw_y, raw_latent),
        )

        sample = {
            'text_feature': text_feature,
            'y': y,
            'clip_fea': clip_fea,
            'neg_text_feature': neg_text_feature,
            'noise_shape': noise_shape,
            'latent': latent,
            'prompt': prompt,
        }
        return sample

class ShardingLMDBDataset(Dataset):
    """Dataset of conditioning features read from a directory of LMDB shards.

    All entries of ``data_path`` are opened, in sorted name order, as
    read-only LMDB environments. Global sample positions are resolved through
    ``index``, a flat list of ``(shard_id, local_idx)`` pairs.

    Args:
        data_path: Directory whose entries are the LMDB shards.
        max_pair: Stored as ``self.max_pair``; this class never reads it back.
    """

    def __init__(self, data_path: str, max_pair: int = int(1e8)):
        self.max_pair = max_pair
        self.index = []
        self.latents_shape = {}

        open_options = dict(readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
        self.envs = [
            lmdb.open(os.path.join(data_path, entry_name), **open_options)
            for entry_name in sorted(os.listdir(data_path))
        ]

        # The helper's return value goes directly into range(), so it acts as
        # the number of rows stored in that shard.
        for shard_id, shard_env in enumerate(self.envs):
            shard_rows = get_array_shape_from_lmdb(shard_env)
            self.latents_shape[shard_id] = shard_rows
            for local_i in range(shard_rows):
                self.index.append((shard_id, local_i))

    def __len__(self):
        """Return the total number of rows across all shards."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict with torch tensors under ``text_feature``, ``y``,
            ``clip_fea`` and ``neg_text_feature``; ``noise_shape`` exactly as
            ``read_numpy_from_lmdb`` returned it; and ``prompt`` as a
            UTF-8 decoded ``str``.
        """
        shard_id, local_idx = self.index[idx]
        shard_env = self.envs[shard_id]

        # Key naming presumably mirrors store_data_dict_to_lmdb, which is
        # defined outside this file.
        text_np = read_numpy_from_lmdb(shard_env, f"text_feature_{local_idx}", np.float32)
        neg_text_np = read_numpy_from_lmdb(shard_env, f"neg_text_feature_{local_idx}", np.float32)
        clip_np = read_numpy_from_lmdb(shard_env, f"clip_fea_{local_idx}", np.float16)
        noise_shape = read_numpy_from_lmdb(shard_env, f"noise_shape_{local_idx}", np.int64)
        y_np = read_numpy_from_lmdb(shard_env, f"y_{local_idx}", np.float32)

        # Prompt text is stored as the uint8 bytes of its UTF-8 encoding.
        prompt_raw = read_numpy_from_lmdb(shard_env, f"prompt_{local_idx}", np.uint8)
        prompt = prompt_raw.tobytes().decode('utf-8')

        # Feature arrays become tensors; noise_shape is returned as read.
        text_feature, neg_text_feature, clip_fea, y = (
            torch.from_numpy(array)
            for array in (text_np, neg_text_np, clip_np, y_np)
        )

        return dict([
            ('text_feature', text_feature),
            ('y', y),
            ('clip_fea', clip_fea),
            ('neg_text_feature', neg_text_feature),
            ('noise_shape', noise_shape),
            ('prompt', prompt),
        ])

def cycle(dl, sampler):
    """Yield items from ``dl`` forever, restarting it after each full pass.

    Before every pass the current epoch counter (starting at 0) is handed to
    ``sampler.set_epoch`` so that an epoch-aware sampler can change its
    shuffling seed. After each pass the counter is incremented and a
    progress line is printed.

    Args:
        dl: Re-iterable source of batches; it is iterated once per epoch.
        sampler: Object exposing ``set_epoch(epoch)``, or ``None`` when no
            epoch-aware sampler is in use (the call is then skipped instead
            of raising ``AttributeError``).

    Yields:
        The items of ``dl`` in iteration order, pass after pass. The
        generator never terminates on its own.
    """
    epoch = 0
    while True:
        # Without a sampler there is nothing to reseed.
        if sampler is not None:
            sampler.set_epoch(epoch)
        # Run one complete epoch.
        for batch in dl:
            yield batch
        epoch += 1
        print(f"FINISH ONE EPOCH, now we are in EPOCH {epoch}")
