import os
import numpy as np
from src.data.base_dataset import TTADatasetBase, DatumRaw
from robustbench.data import load_cifar10c, load_cifar100c

class CorruptionCIFAR_Lazy(TTADatasetBase):
    """Lazily loaded CIFAR-10-C / CIFAR-100-C test-time-adaptation dataset.

    Only ``labels.npy`` is read eagerly. Every sample is described by a
    lightweight ``DatumRaw`` record whose ``x`` field is the tuple
    ``(corruption, severity, index_within_severity)``; the pixel data is read
    from a memory-mapped ``<corruption>.npy`` file in ``__getitem__``.

    Expected file layout (the standard CIFAR-C release, as consumed by
    robustbench's loader): each ``<corruption>.npy`` stacks the
    ``NUM_SEVERITY_LEVELS`` severity levels back to back along axis 0, with
    severity 1 first, and every severity block has the same length.
    """

    # Number of equally sized severity blocks stacked in each corruption file.
    NUM_SEVERITY_LEVELS = 5

    def __init__(self, cfg, all_corruption, all_severity):
        """Build the per-sample metadata index without loading any image.

        Args:
            cfg: Config object; reads ``cfg.CORRUPTION.DATASET``
                (``"cifar10c"`` or ``"cifar100c"``), ``cfg.CORRUPTION.NUM_EX``
                (max samples per corruption/severity pair) and
                ``cfg.DATA_DIR``.
            all_corruption: A corruption name or a list of names.
            all_severity: A severity level (1..5) or a list of levels.

        Raises:
            ValueError: Unsupported dataset name, or a severity outside
                ``1..NUM_SEVERITY_LEVELS``.
            FileNotFoundError: A requested corruption file does not exist.
        """
        all_corruption = all_corruption if isinstance(all_corruption, list) else [all_corruption]
        all_severity = all_severity if isinstance(all_severity, list) else [all_severity]

        self.cfg = cfg
        self.corruptions = all_corruption
        self.severity = all_severity
        self.load_image = None

        if cfg.CORRUPTION.DATASET == "cifar10c":
            self.dataset = "cifar10"
            self.load_image = load_cifar10c
        elif cfg.CORRUPTION.DATASET == "cifar100c":
            self.dataset = "cifar100"
            self.load_image = load_cifar100c
        else:
            raise ValueError(f"Unsupported dataset: {cfg.CORRUPTION.DATASET}")

        self.data_dir = cfg.DATA_DIR
        self.domain_id_to_name = {}
        self.data_source = []

        # Fail early on invalid severities: a bad value would otherwise map to
        # a wrong (possibly negative, i.e. wrapped-around) row offset later.
        for severity in self.severity:
            self._severity_block(severity)

        # Only the labels are loaded eagerly (this array is very small).
        self.labels = np.load(os.path.join(self.data_dir, self.dataset + "-c", "labels.npy"))

        # Read each corruption file's header once (memory-mapped, so no pixel
        # data is loaded) to learn how many samples one severity block holds.
        samples_per_severity = {}
        for corruption in self.corruptions:
            data_path = self._corruption_path(corruption)
            if not os.path.exists(data_path):
                raise FileNotFoundError(f"Missing corruption file: {data_path}")
            total_rows = np.load(data_path, mmap_mode='r').shape[0]
            samples_per_severity[corruption] = total_rows // self.NUM_SEVERITY_LEVELS

        for i_s, severity in enumerate(self.severity):
            for i_c, corruption in enumerate(self.corruptions):
                d_name = f"{corruption}_{severity}"
                d_id = i_s * len(self.corruptions) + i_c
                self.domain_id_to_name[d_id] = d_name

                # Cap at one severity block so indices never spill into the
                # next severity level of the same file.
                num_records = min(samples_per_severity[corruption], cfg.CORRUPTION.NUM_EX)
                for i in range(num_records):
                    # Record metadata only; the image is loaded on access.
                    # The label is looked up with the within-severity index,
                    # the same way robustbench's CIFAR-C loader pairs labels
                    # with a severity slice (labels repeat per severity block).
                    self.data_source.append(
                        DatumRaw((corruption, severity, i), self.labels[i].item(), d_id)
                    )

        super().__init__(cfg, self.data_source)

    def _corruption_path(self, corruption):
        """Return the path of the ``.npy`` file holding ``corruption`` images."""
        return os.path.join(self.data_dir, self.dataset + "-c", f"{corruption}.npy")

    @classmethod
    def _severity_block(cls, severity):
        """Return the zero-based severity block index for ``severity``.

        Raises:
            ValueError: ``severity`` is not within ``1..NUM_SEVERITY_LEVELS``.
        """
        level = int(severity)
        if not 1 <= level <= cls.NUM_SEVERITY_LEVELS:
            raise ValueError(
                f"Severity must be in 1..{cls.NUM_SEVERITY_LEVELS}, got {severity!r}"
            )
        return level - 1

    def __getitem__(self, index):
        """Load one sample from disk.

        Returns:
            dict with ``"image"`` (float tensor scaled by 1/255, last array
            axis moved to the front), ``"label"`` (long tensor) and
            ``"domain"`` (the record's domain id).
        """
        item = self.data_source[index]
        # x is a metadata tuple here, not the actual image.
        corruption, severity, sample_idx = item.x

        # Memory-map the file and read only the requested row. The row is
        # offset by the severity block; without the offset every severity
        # would return severity-1 images.
        images = np.load(self._corruption_path(corruption), mmap_mode='r')
        block_size = images.shape[0] // self.NUM_SEVERITY_LEVELS
        row = self._severity_block(severity) * block_size + sample_idx
        # Copy out of the read-only memory map into a regular array.
        x_data = np.array(images[row])

        # Convert to tensors and scale pixel values to [0, 1].
        import torch
        # Last axis goes first (assumes images are stored as H x W x C).
        x = torch.from_numpy(x_data).permute(2, 0, 1).float() / 255.0
        y = torch.tensor(item.y).long()
        d = item.domain

        return {"image": x, "label": y, "domain": d}
