import torch
from torch.utils.data import Dataset
import glob
import os
from tqdm import tqdm
import bisect # 用于高效地查找索引

class SDImageDatasetOriginal(Dataset):
    def __init__(self, data_path, tokenizer_one, is_sdxl=False, tokenizer_two=None):
        self.data_path = data_path
        self.is_sdxl = is_sdxl
        self.tokenizer_one = tokenizer_one
        self.tokenizer_two = tokenizer_two

        print("Initializing dataset... Scanning files to map indices.")
        self.files = sorted(glob.glob(os.path.join(self.data_path, "*.pt")))
        
        self.cumulative_sizes = []
        total_samples = 0
        for f in tqdm(self.files, desc="Scanning data files"):
            # 只加载tensor来获取长度，避免加载大的字符串列表消耗内存
            # 假设每个文件都包含 'latents' key
            num_samples_in_file = torch.load(f)['latents'].shape[0]
            total_samples += num_samples_in_file
            self.cumulative_sizes.append(total_samples)
        
        # 如果文件中没有样本，cumulative_sizes会是空的
        if not self.cumulative_sizes:
            self.length = 0
        else:
            self.length = self.cumulative_sizes[-1]
            
        print(f"Dataset initialized. Found {self.length} samples in {len(self.files)} files.")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if idx < 0 or idx >= self.length:
            raise IndexError("Index out of range")

        # 1. 找到该全局索引 'idx' 对应的文件
        # bisect_right 可以在 self.cumulative_sizes 中快速找到 idx 所在的位置
        file_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        
        # 2. 计算在该文件内的局部索引
        # 如果 file_idx > 0, 则从前一个文件的累积大小中减去
        local_idx = idx
        if file_idx > 0:
            local_idx = idx - self.cumulative_sizes[file_idx - 1]
            
        # 3. 加载对应的文件
        file_path = self.files[file_idx]
        data_dict = torch.load(file_path)
        
        # 4. 从加载的字典中提取数据
        # 假设您的 .pt 文件中 'latents' 和 'prompts' 的长度是一致的
        image = data_dict['latents'][local_idx]
        # 根据您的lmdb创建代码，prompts的key应该是 'prompts'
        prompt = data_dict['prompts'][local_idx] 
        
        # 5. 进行与 LMDB 版本中相同的处理
        image = image.to(dtype=torch.float32)

        text_input_ids_one = self.tokenizer_one(
            prompt, # 注意：这里传入的是单个字符串，而不是列表
            padding="max_length",
            max_length=self.tokenizer_one.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.squeeze(0) # squeeze(0) 去掉批次维度

        output_dict = { 
            'images': image,
            'text_input_ids_one': text_input_ids_one,
        }

        if self.is_sdxl:
            text_input_ids_two = self.tokenizer_two(
                prompt, # 同样是单个字符串
                padding="max_length",
                max_length=self.tokenizer_two.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids.squeeze(0) # squeeze(0) 去掉批次维度
            output_dict['text_input_ids_two'] = text_input_ids_two

        return output_dict