import fnmatch
import json
import torch
from torchvision.datasets import VisionDataset, ImageFolder
from torchvision.datasets.folder import default_loader
from PIL import Image
import warnings
from sys import exit
import shutil
from configs.femnist import IMG_DIM
import os
import cv2
import numpy as np
import lmdb
from shutil import move
from tqdm import tqdm
from torchvision import transforms

class FEMNIST(VisionDataset):
    """
    LEAF FEMNIST dataset, partitioned by user.

    classes: 10 digits, 26 lower cases, 26 upper cases.
    Each user's samples are cached as ``processed/{train,test}_<user_id>.pt`` with
    torch.save and read back with torch.load.
    """

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, user_list: list = None):
        """
        Args:
            root (str): Dataset root; ``raw/`` and ``processed/`` are created below it.
            train (bool): Load the train split if True, otherwise the test split.
            transform (callable, optional): Applied to the PIL image.
            target_transform (callable, optional): Applied to the integer label.
            download (bool): Fetch and preprocess the data if it is not present yet.
            user_list (list, optional): User ids to load; each must satisfy
                ``0 <= user_id < total_num_users``. ``None`` loads every user.

        Raises:
            FileNotFoundError: If the processed dataset is missing.
        """
        super(FEMNIST, self).__init__(root, transform=transform, target_transform=target_transform)
        self.train = train
        self.user_list = user_list

        if download:
            self.download()

        if not self._check_exists():
            raise FileNotFoundError("Dataset not found. You can use download=True to download it")

        self.total_num_users = torch.load(os.path.join(self.processed_folder, "num_users.pt"), weights_only=True)

        if self.user_list is not None:
            self.num_users = len(self.user_list)
        else:
            self.user_list = list(range(self.total_num_users))
            self.num_users = self.total_num_users

        if self.train:
            self.data, self.targets = self.load(train=True)
        else:
            self.data, self.targets = self.load(train=False)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # Return a PIL Image so this dataset behaves like the other torchvision datasets.
        # PIL needs uint8 in 0..255; the stored pixels are presumably floats in [0, 1]
        # (TODO confirm), and ``1 - img`` inverts them.
        img = Image.fromarray(np.uint8(255 * (1 - img.numpy())), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        return os.path.join(self.root, "raw")

    @property
    def all_data_folder(self):
        return os.path.join(self.root, "femnist", "data", "raw_data")

    @property
    def processed_folder(self):
        return os.path.join(self.root, "processed")

    def _check_exists(self):
        return (os.path.exists(os.path.join(self.raw_folder, "train")) and
                os.path.exists(os.path.join(self.raw_folder, "test")) and
                os.path.exists(os.path.join(self.processed_folder, "num_users.pt")))

    def download(self):
        """Fetch the LEAF FEMNIST data if needed, run LEAF preprocessing and build ``processed/``.

        Shells out to ``git`` and LEAF's ``preprocess.sh``. Before cloning it asks for
        confirmation on stdin and exits the interpreter if the user declines.
        """
        if self._check_exists():
            return

        # Raw splits were already exported; only the per-user tensors are missing.
        if os.path.isdir(self.raw_folder) and len(os.listdir(self.raw_folder)) != 0:
            self.process()
            return

        root = self.root
        os.makedirs(root, exist_ok=True)

        if not os.path.exists(self.all_data_folder):
            # download from https://github.com/TalwalkarLab/leaf/tree/master/data/femnist
            input_str = input("Downloading and processing data will take "
                              "approximately 10 to 30 minutes, and it consumes about 15GB of storage. Continue? [y/n]")
            if input_str.lower() in ["y", "yes"]:
                os.system(rf"git clone https://github.com/TalwalkarLab/leaf.git {root}/github_repo"
                          rf"&& cd {root}/github_repo/data/femnist"
                          r"&& ./preprocess.sh -s niid --sf 0.05 -k 0 -t sample"
                          r"&& cd ../../.."
                          r"&& mv github_repo/data/utils utils"
                          r"&& mv github_repo/data/femnist femnist"
                          r"&& rm -rf github_repo")
                os.makedirs(self.raw_folder, exist_ok=False)
                os.system(rf"cd {root}"
                          r"&& rm -r femnist/data/rem_user_data femnist/data/sampled_data"
                          r"&& mv femnist/data/test raw/ && mv femnist/data/train raw/")
            else:
                print("Exiting...")
                exit()
        else:
            # LEAF's raw data is already on disk. preprocess.sh writes its outputs
            # into <root>/femnist/data/, so remove those, along with our exported
            # raw/ and processed/ folders, before running it again.
            stale_paths = [os.path.join(root, "femnist", "data", name)
                           for name in ("rem_user_data", "sampled_data", "train", "test")]
            stale_paths += [self.raw_folder, self.processed_folder]
            for path in stale_paths:
                if os.path.isdir(path):
                    shutil.rmtree(path)
                elif os.path.exists(path):
                    os.remove(path)

            os.makedirs(self.raw_folder, exist_ok=False)
            os.system(rf"cd {root}/femnist"
                      r"&& ./preprocess.sh -s niid --sf 0.05 -k 0 -t sample"
                      r"&& cd .."
                      r"&& rm -r femnist/data/rem_user_data femnist/data/sampled_data"
                      r"&& mv femnist/data/test raw/ && mv femnist/data/train raw/")

        self.process()

    def _process_split(self, split, suffix_len):
        """Save one ``<split>_<k>.pt`` file per user found in ``raw/<split>``; return the user count.

        Files are ordered by the integer between the 9-character ``all_data_`` prefix
        and the last ``suffix_len`` characters of the name. Users are numbered
        consecutively in that file order, then in JSON order within each file.
        """
        split_dir = os.path.join(self.raw_folder, split)
        file_names = [f for f in os.listdir(split_dir) if fnmatch.fnmatch(f, "*.json")]
        file_names.sort(key=lambda fname: int(fname[9:-suffix_len]))

        num_users = 0
        for filename in file_names:
            with open(os.path.join(split_dir, filename)) as file:
                data = json.load(file)
            # Each value is a dict {"x": x_data, "y": y_data}; the key (user name) is not kept.
            for val in data["user_data"].values():
                x = torch.tensor(val["x"]).reshape((-1, *IMG_DIM))
                y = torch.tensor(val["y"])
                # Format only the file name, so braces in the root path are never interpreted.
                torch.save((x, y), os.path.join(self.processed_folder, "{}_{}.pt".format(split, num_users)))
                num_users += 1
        return num_users

    def process(self):
        """Convert the LEAF JSON shards in ``raw/`` into per-user tensors in ``processed/``."""
        print("Processing data...")

        os.makedirs(self.processed_folder, exist_ok=True)

        # The suffix lengths strip "_niid_05_keep_0_{train,test}_<k>.json" style names.
        total_users_train = self._process_split("train", 28)
        total_users_test = self._process_split("test", 27)

        assert total_users_train == total_users_test, \
            "train/test user counts differ: {} vs {}".format(total_users_train, total_users_test)
        torch.save(total_users_train, os.path.join(self.processed_folder, "num_users.pt"))
        print("Done. {} users processed.".format(total_users_train))

    def load(self, train):
        """Concatenate the cached tensors of every user in ``self.user_list``.

        Args:
            train (bool): Read ``train_*.pt`` files if True, otherwise ``test_*.pt``.

        Returns:
            tuple: ``(data, targets)``. Both are empty tensors when ``user_list`` is empty.
        """
        prf = "train" if train else "test"

        data_list, label_list = [], []
        for user_id in self.user_list:
            x, y = torch.load(
                os.path.join(self.processed_folder, "{}_{}.pt".format(prf, user_id)),
                weights_only=True  # restrict unpickling to tensors/primitives; avoids torch's FutureWarning
            )
            data_list.append(x)
            label_list.append(y)
        if not data_list:
            # torch.cat rejects an empty list, so build empty tensors explicitly.
            return torch.empty((0, *IMG_DIM)), torch.empty(0, dtype=torch.long)
        return torch.cat(data_list, dim=0), torch.cat(label_list, dim=0)


class CelebA(VisionDataset):
    """
    The Leaf CelebA dataset. See "https://github.com/TalwalkarLab/leaf/tree/master/data/celeba" for details.
    We use torch.save, torch.load in this dataset.

    Images are read from ``raw/img_align_celeba``. ``processed/{train,test}_meta.pt``
    hold one entry per user, ordered by user name, with the image file names
    (``"x"``) and labels (``"y"``).
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, user_list: list = None):
        """
        Args:
            root (str): Dataset root containing ``raw/`` and ``processed/``.
            train (bool): Load the train split if True, otherwise the test split.
            transform (callable, optional): Applied to the loaded image.
            target_transform (callable, optional): Applied to the label.
            download (bool): Process already-downloaded raw files if needed.
            user_list (list, optional): Indices of users to load. ``None`` loads all.

        Raises:
            FileNotFoundError: If the processed dataset is missing.
        """
        super(CelebA, self).__init__(root, transform=transform, target_transform=target_transform)
        self.train = train
        self.user_list = user_list
        self.loader = default_loader

        if download:
            self.download()

        if not self._check_exists():
            raise FileNotFoundError("Dataset not found. You can use download=True to download it")

        self.total_num_users = torch.load(os.path.join(self.processed_folder, "num_users.pt"))

        if self.user_list is not None:
            self.num_users = len(self.user_list)
        else:
            self.user_list = list(range(self.total_num_users))
            self.num_users = self.total_num_users

        if self.train:
            self.img_paths, self.labels = self.load(train=True)
        else:
            self.img_paths, self.labels = self.load(train=False)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        path, target = self.img_paths[index], self.labels[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.img_paths)

    @property
    def raw_folder(self):
        return os.path.join(self.root, "raw")

    @property
    def processed_folder(self):
        return os.path.join(self.root, "processed")

    @property
    def leaf_github_folder(self):
        return os.path.join(self.root, "github_repo")

    def _check_exists(self):
        return (os.path.isdir(os.path.join(self.raw_folder, "img_align_celeba")) and
                os.path.isfile(os.path.join(self.processed_folder, "num_users.pt")) and
                os.path.isfile(os.path.join(self.processed_folder, "train_meta.pt")) and
                os.path.isfile(os.path.join(self.processed_folder, "test_meta.pt"))
                )

    def download(self):
        """Process manually downloaded raw files; the raw files themselves cannot be fetched automatically."""
        if self._check_exists():
            return

        if os.path.isdir(self.raw_folder) and len(os.listdir(self.raw_folder)) != 0:
            self.process()
            return
        elif not os.path.isdir(self.raw_folder):
            os.makedirs(self.raw_folder, exist_ok=False)

        print("Please download the required files to the \"raw\" subdirectory following the instructions from "
              r"https://github.com/TalwalkarLab/leaf/tree/master/data/celeba. Exiting...")
        exit()

    def _read_sorted_users(self, split):
        """Return the per-user ``{"x": ..., "y": ...}`` entries of ``processed/<split>``, ordered by user name."""
        split_dir = os.path.join(self.processed_folder, split)
        json_files = [f for f in os.listdir(split_dir) if fnmatch.fnmatch(f, "*.json")]
        assert len(json_files) == 1
        with open(os.path.join(split_dir, json_files[0])) as file:
            user_data = json.load(file)["user_data"]
        return [user_data[user_name] for user_name in sorted(user_data)]

    def process(self):
        """Run LEAF's CelebA preprocessing and save user metadata to ``processed/``.

        The raw files are temporarily moved into a LEAF checkout at ``github_repo``,
        ``preprocess.sh`` runs there, and afterwards the raw files are moved back
        and the resulting train/test JSON is moved into ``processed/``.

        Raises:
            FileNotFoundError: If ``raw/img_align_celeba`` is missing.
            RuntimeError: If cloning the LEAF repository fails.
        """
        print("Processing data...")

        image_dir = os.path.join(self.raw_folder, "img_align_celeba")
        if not os.path.exists(image_dir):
            raise FileNotFoundError("Missing {}; download the CelebA images first.".format(image_dir))

        os.makedirs(self.processed_folder, exist_ok=True)

        root = self.root
        assert not os.path.isdir(os.path.join(self.processed_folder, "train"))
        assert not os.path.isdir(os.path.join(self.processed_folder, "test"))
        if not os.path.exists(self.leaf_github_folder):
            status = os.system(rf"git clone https://gitee.com/hld2018/leaf.git "
                               rf"{root}/github_repo")
            if status != 0:
                # Abort before any raw file is moved into a checkout that does not exist.
                raise RuntimeError("Cloning the LEAF repository failed (status {}).".format(status))
        os.makedirs(f"{root}/github_repo/data/celeba/data/raw", exist_ok=True)
        status = os.system(rf"mv {root}/raw/img_align_celeba {root}/github_repo/data/celeba/data/raw"
                           rf"&& mv {root}/raw/identity_CelebA.txt {root}/github_repo/data/celeba/data/raw"
                           rf"&& mv {root}/raw/list_attr_celeba.txt {root}/github_repo/data/celeba/data/raw"
                           rf"&& cd {root}/github_repo/data/celeba"
                           r"&& ./preprocess.sh -s niid --sf 1. -k 0 -t sample"
                           r"&& cd ../../.."
                           r"&& mv github_repo/data/celeba/data/train processed/"
                           r"&& mv github_repo/data/celeba/data/test processed/"
                           r"&& mv github_repo/data/celeba/data/raw/img_align_celeba raw/"
                           r"&& mv github_repo/data/celeba/data/raw/identity_CelebA.txt raw/"
                           r"&& mv github_repo/data/celeba/data/raw/list_attr_celeba.txt raw/"
                           r"&& rm -rf github_repo")
        if status != 0:
            # The && chain stopped early, so some raw files may still sit inside the checkout.
            warnings.warn("LEAF CelebA preprocessing exited with status {}; raw files may remain under {}"
                          .format(status, self.leaf_github_folder))

        renamed_users_train = self._read_sorted_users("train")
        renamed_users_test = self._read_sorted_users("test")
        total_users_train = len(renamed_users_train)

        assert total_users_train == len(renamed_users_test)
        torch.save(total_users_train, os.path.join(self.processed_folder, "num_users.pt"))
        torch.save(renamed_users_train, os.path.join(self.processed_folder, "train_meta.pt"))
        torch.save(renamed_users_test, os.path.join(self.processed_folder, "test_meta.pt"))
        print("Done. {} users processed.".format(total_users_train))

    def load(self, train):
        """Return ``(image_paths, labels)`` flattened over the users in ``self.user_list``."""
        prf = "train" if train else "test"

        meta_data = torch.load(os.path.join(self.processed_folder, "{}_meta.pt".format(prf)))

        path_list, label_list = [], []
        for user_id in self.user_list:
            x, y = meta_data[user_id]["x"], meta_data[user_id]["y"]
            path_list.extend([os.path.join(self.raw_folder, "img_align_celeba", p) for p in x])
            label_list.extend(y)
        return path_list, label_list


class ImageNet100(VisionDataset):
    """ImageNet restricted to the first 100 synsets listed in ``LOC_synset_mapping.txt``.

    Samples are served from ``imagenet100_train.lmdb`` / ``imagenet100_val.lmdb``
    under ``root`` through ``TinyImageNetDataset``.
    """

    def __init__(self, root, data_type, transform=None, target_transform=None, download=False):
        """
        Args:
            root (str): Directory holding the two LMDB files.
            data_type (str): ``"train"`` or ``"val"``.
            transform (callable, optional): Applied to each PIL image.
            target_transform (callable, optional): Stored by VisionDataset; not applied here.
            download (bool): Call ``download`` first (which cannot fetch the data).

        Raises:
            FileNotFoundError: If either LMDB file is missing.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)
        assert data_type in ["train", "val"], "data_type must be 'train' or 'val'."

        self.root = root
        self.lmdb_train_path = os.path.join(self.root, "imagenet100_train.lmdb")
        self.lmdb_val_path = os.path.join(self.root, "imagenet100_val.lmdb")

        if download:
            self.download()

        if not self._check_exists():
            raise FileNotFoundError("Dataset not found. Use download=True to download it.")

        path_to_data = self.lmdb_train_path if data_type == "train" else self.lmdb_val_path

        self.dataset = TinyImageNetDataset(path_to_data, transform=transform)
        self.targets = self.dataset.labels

    def download(self):
        """Return if the LMDB files exist; otherwise create ``root`` and raise, since the data must be fetched manually."""
        if self._check_exists():
            print("Data already downloaded and processed.")
            return

        os.makedirs(self.root, exist_ok=True)

        raise NotImplementedError("please download from Kaggle: https://www.kaggle.com/c/imagenet-object-localization-challenge/data, then process")

    def process(self, src_root='./ILSVRC', out_root='../ImageNet100', image_size=(168, 168)):
        """Build the train/val LMDB files for the first 100 synsets from an extracted ILSVRC tree.

        Args:
            src_root (str): Extracted Kaggle ``ILSVRC`` directory.
            out_root (str): Directory the two ``imagenet100_*.lmdb`` files are written to.
            image_size (tuple): ``(width, height)`` every image is resized to.
        """
        val_solution_file = os.path.join(src_root, 'LOC_val_solution.csv')
        mapping_file = os.path.join(src_root, 'LOC_synset_mapping.txt')
        train_image_dir = os.path.join(src_root, 'Data', 'CLS-LOC', 'train')
        val_image_dir = os.path.join(src_root, 'Data', 'CLS-LOC', 'val')
        train_lmdb_path = os.path.join(out_root, 'imagenet100_train.lmdb')
        val_lmdb_path = os.path.join(out_root, 'imagenet100_val.lmdb')

        # Step 1: the first 100 WordNet ids of the synset mapping define the classes.
        top_100_wnids = []
        with open(mapping_file, 'r') as f:
            for idx, line in enumerate(f):
                if idx >= 100:
                    break
                top_100_wnids.append(line.strip().split(' ')[0])
        wnid_set = set(top_100_wnids)  # O(1) membership checks for the val scan

        print("前100个类别的 WordNet ID:", top_100_wnids)

        # Step 2: training images live in one directory per WordNet id.
        train_image_list = []
        for wnid in top_100_wnids:
            class_dir = os.path.join(train_image_dir, wnid)
            if not os.path.exists(class_dir):
                print(f"⚠️ Warning: Class directory {class_dir} does not exist.")
                continue

            for image_name in os.listdir(class_dir):
                if image_name.endswith('.JPEG'):
                    train_image_list.append((os.path.join(class_dir, image_name), wnid))

        # Step 3: validation labels come from the first token of each PredictionString.
        val_image_list = []
        with open(val_solution_file, 'r') as f:
            lines = f.readlines()[1:]  # skip the header row
            for line in lines:
                parts = line.strip().split(',', 1)
                if len(parts) < 2:
                    continue

                image_id, prediction_string = parts
                tokens = prediction_string.split()
                if not tokens:
                    continue  # row without any prediction

                wnid = tokens[0]
                if wnid in wnid_set:
                    val_image_list.append((os.path.join(val_image_dir, f"{image_id}.JPEG"), wnid))

        print(f"筛选到属于前100类的验证集图像数量: {len(val_image_list)}")
        print("验证集示例数据:", val_image_list[:5])

        # Step 4: write both splits to LMDB.
        self.convert_to_lmdb(top_100_wnids, train_image_list, train_lmdb_path, image_size)
        self.convert_to_lmdb(top_100_wnids, val_image_list, val_lmdb_path, image_size)

    @staticmethod
    def convert_to_lmdb(top_100_wnids, image_list, lmdb_path, image_size):
        """
        Write images and integer labels to an LMDB database.

        Each image is resized, JPEG-encoded and stored under ``image-<idx>``. Its
        label, the position of its WordNet id in ``top_100_wnids`` (or -1 if the id
        is absent), is stored as one int64 under ``label-<idx>``. Missing or
        unreadable images are skipped, which leaves gaps in ``idx``.

        Args:
            top_100_wnids (list): WordNet ids; list position is the integer label.
            image_list (list): ``(image_path, wnid)`` tuples.
            lmdb_path (str): Path of the LMDB database to write.
            image_size (tuple): Target size ``(width, height)``.
        """
        # Create the parent directory if needed; dirname is '' for a bare file name.
        lmdb_dir = os.path.dirname(lmdb_path)
        if lmdb_dir and not os.path.exists(lmdb_dir):
            os.makedirs(lmdb_dir)
            print(f"🔄 目录 {lmdb_dir} 不存在，已自动创建。")

        label_mapping = {wnid: idx for idx, wnid in enumerate(top_100_wnids)}
        print("📚 类别映射表:", label_mapping)

        # Upper bound: twice the raw pixel size; JPEG-encoded images are smaller.
        map_size = len(image_list) * 3 * image_size[0] * image_size[1] * 2
        env = lmdb.open(lmdb_path, map_size=map_size)

        try:
            with env.begin(write=True) as txn:
                for idx, (image_path, wnid) in enumerate(tqdm(image_list, desc=f"Writing to {lmdb_path}")):
                    if not os.path.exists(image_path):
                        print(f"⚠️ Warning: Image not found {image_path}")
                        continue

                    image = cv2.imread(image_path)
                    if image is None:
                        print(f"❌ Error: Unable to read image {image_path}")
                        continue

                    image = cv2.resize(image, image_size)
                    _, buffer = cv2.imencode('.jpg', image)
                    txn.put(f"image-{idx:05d}".encode(), buffer.tobytes())

                    label = label_mapping.get(wnid, -1)
                    txn.put(f"label-{idx:05d}".encode(), np.array(label, dtype=np.int64).tobytes())
        finally:
            env.close()

        print(f"✅ LMDB 数据库创建完成: {lmdb_path}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def _check_exists(self):
        return os.path.exists(self.lmdb_train_path) and os.path.exists(self.lmdb_val_path)


import lmdb
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

import lmdb
import cv2
import numpy as np
from PIL import Image
from torchvision.datasets import VisionDataset


class TinyImageNetDataset(VisionDataset):
    """Read-only dataset over an LMDB holding ``image-<idx>`` / ``label-<idx>`` entries.

    Images are stored as encoded bytes and decoded with OpenCV. Labels are a single
    int64; a missing or undecodable label is reported as -1.
    """

    def __init__(self, lmdb_path, transform=None):
        super().__init__(lmdb_path, transform=transform)
        self.lmdb_path = lmdb_path
        self.transform = transform
        self.keys = []
        self.labels = []

        # Open the environment once to index every image key and its label, then
        # close it so this object holds no LMDB handle afterwards.
        env = self._open_env()
        try:
            with env.begin(write=False) as txn:
                for key, _ in txn.cursor():
                    key_str = key.decode()
                    if not key_str.startswith("image"):
                        continue
                    self.keys.append(key_str)
                    raw_label = txn.get(key_str.replace("image", "label").encode())
                    if not raw_label:
                        self.labels.append(-1)
                        continue
                    try:
                        self.labels.append(np.frombuffer(raw_label, dtype=np.int64)[0])
                    except (ValueError, IndexError) as e:
                        print(f"❌ Error decoding label for {key}: {e}")
                        self.labels.append(-1)
        finally:
            env.close()

    def _open_env(self):
        """Open the LMDB environment read-only, without locking or readahead."""
        return lmdb.open(self.lmdb_path, readonly=True, lock=False, readahead=False, meminit=False)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(image, label)``, where the image is an RGB PIL image (after ``transform`` if set).

        Raises:
            ValueError: If the image bytes are missing or cannot be decoded.
        """
        # Open a fresh environment per call, and always close it, even when an
        # error is raised below.
        env = self._open_env()
        try:
            with env.begin(write=False) as txn:
                image_key = self.keys[idx].encode()
                image_buffer = txn.get(image_key)
                if image_buffer is None:
                    raise ValueError(f"Missing image data for key: {image_key}")

                image = cv2.imdecode(np.frombuffer(image_buffer, np.uint8), cv2.IMREAD_COLOR)
                if image is None:
                    raise ValueError(f"Failed to decode image data for key: {image_key}")
        finally:
            env.close()

        # OpenCV decodes to BGR; PIL expects RGB.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


class TinyImageNet(VisionDataset):
    """Tiny ImageNet (200 classes, 64x64), converted once to LMDB and served via ``TinyImageNetDataset``."""

    def __init__(self, root, data_type, transform=None, target_transform=None, download=False):
        """
        Args:
            root (str): Directory for the zip, the extracted data and the LMDB files.
            data_type (str): ``"train"`` or ``"val"``.
            transform (callable, optional): Applied to each PIL image.
            target_transform (callable, optional): Stored by VisionDataset; not applied here.
            download (bool): Download and convert the data if the LMDB files are missing.

        Raises:
            FileNotFoundError: If the LMDB files are missing.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)
        assert data_type in ["train", "val"], "data_type must be 'train' or 'val'."
        self.parent_dir = root
        self.root = root
        self.train_folder = os.path.join(self.root, "tiny-imagenet-200", "train")
        self.val_folder = os.path.join(self.root, "tiny-imagenet-200", "val")
        self.lmdb_train_path = os.path.join(self.root, "tiny_imagenet_train.lmdb")
        self.lmdb_val_path = os.path.join(self.root, "tiny_imagenet_val.lmdb")

        if download:
            self.download()

        if not self._check_exists():
            raise FileNotFoundError("Dataset not found. Use download=True to download it.")

        path_to_data = self.lmdb_train_path if data_type == "train" else self.lmdb_val_path

        self.dataset = TinyImageNetDataset(path_to_data, transform=transform)
        self.targets = self.dataset.labels

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def _check_exists(self):
        return os.path.exists(self.lmdb_train_path) and os.path.exists(self.lmdb_val_path)

    def download(self):
        """Download and unzip Tiny ImageNet with ``wget``/``unzip``, then build the LMDB files."""
        if self._check_exists():
            print("Data already downloaded.")
            return

        os.makedirs(self.root, exist_ok=True)

        os.system(rf"cd {self.root} "
                  r"&& wget -nc http://cs231n.stanford.edu/tiny-imagenet-200.zip "
                  r"&& unzip -o tiny-imagenet-200.zip")

        self.process()

    def process(self):
        """Sort the val images into per-class folders, then write train/val LMDB files.

        Raises:
            FileNotFoundError: If ``val/val_annotations.txt`` is missing.
        """
        print("Processing validation set...")

        annotations_path = os.path.join(self.val_folder, "val_annotations.txt")
        if not os.path.exists(annotations_path):
            raise FileNotFoundError(f"Missing val_annotations.txt at {annotations_path}")

        # Class sub-directories with no leftover "images" folder mean a previous run
        # already reorganized the validation split.
        categories = [d for d in os.listdir(self.val_folder) if os.path.isdir(os.path.join(self.val_folder, d))]
        if len(categories) > 0 and "images" not in categories:
            print("Validation set already processed into categories. Skipping reorganization.")
        else:
            # Map each val file to its class (tab-separated columns 0 and 1) and create the class folders.
            file_to_class = {}
            with open(annotations_path, 'r') as f:
                for line in f:
                    split_line = line.split('\t')
                    file_name, class_name = split_line[0], split_line[1]
                    file_to_class[file_name] = class_name
                    dir_path = os.path.join(self.val_folder, class_name)
                    if not os.path.exists(dir_path):
                        os.mkdir(dir_path)

            # Move every annotated image into its class folder.
            image_folder_path = os.path.join(self.val_folder, "images")
            for img_name in os.listdir(image_folder_path):
                if img_name in file_to_class:
                    img_class = file_to_class[img_name]
                    move(os.path.join(image_folder_path, img_name), os.path.join(self.val_folder, img_class, img_name))

            # Remove the now-empty images folder (os.rmdir fails if anything was left behind).
            os.rmdir(image_folder_path)
            print("Validation set processing complete.")

        # Label indices follow the sorted training class names, so both splits agree.
        classes = sorted(os.listdir(self.train_folder))
        class_to_idx = {class_name: idx for idx, class_name in enumerate(classes)}

        self.convert_to_lmdb(self.train_folder, self.lmdb_train_path, dataset_type="train", class_to_idx=class_to_idx)
        self.convert_to_lmdb(self.val_folder, self.lmdb_val_path, dataset_type="val", class_to_idx=class_to_idx)

    @staticmethod
    def _list_class_images(class_folder):
        """Return the file paths holding one class's images.

        The raw train split keeps pictures in ``<class>/images/`` next to a
        ``<class>_boxes.txt`` file, while the reorganized val split stores them
        directly in ``<class>/``. Both layouts are handled.
        """
        images_dir = os.path.join(class_folder, "images")
        search_dir = images_dir if os.path.isdir(images_dir) else class_folder
        return [os.path.join(search_dir, name) for name in os.listdir(search_dir)
                if os.path.isfile(os.path.join(search_dir, name))]

    def convert_to_lmdb(self, image_folder, lmdb_path, dataset_type="train", class_to_idx=None):
        """Write the images under ``image_folder`` to LMDB as 64x64 JPEGs plus int64 labels.

        Args:
            image_folder (str): Split root; one sub-directory per class for train/val.
            lmdb_path (str): Path of the LMDB database to write.
            dataset_type (str): ``"train"``, ``"val"`` or ``"test"`` (unlabeled, flat folder).
            class_to_idx (dict, optional): Class name -> integer label. Classes missing
                from it get label -1, and no ``label-<idx>`` entry is written for them.
        """
        print(f"Converting {dataset_type} dataset to LMDB format...")
        image_paths = []
        labels = []

        if dataset_type in ("train", "val"):
            for class_name in os.listdir(image_folder):
                class_folder = os.path.join(image_folder, class_name)
                if not os.path.isdir(class_folder):
                    continue  # skip stray files such as val_annotations.txt

                if class_to_idx and class_name in class_to_idx:
                    label = class_to_idx[class_name]
                else:
                    print(f"Warning: Class name {class_name} not found in class_to_idx mapping.")
                    label = -1

                for file_path in self._list_class_images(class_folder):
                    image_paths.append(file_path)
                    labels.append(label)

        elif dataset_type == "test":
            # The test split has no labels.
            for file in os.listdir(image_folder):
                image_paths.append(os.path.join(image_folder, file))
                labels.append(-1)

        # Upper bound: twice the raw 64x64x3 pixel size; JPEG-encoded images are smaller.
        map_size = len(image_paths) * 3 * 64 * 64 * 2
        env = lmdb.open(lmdb_path, map_size=map_size)

        try:
            with env.begin(write=True) as txn:
                for idx, (image_path, label) in enumerate(tqdm(zip(image_paths, labels), total=len(image_paths))):
                    image = cv2.imread(image_path)
                    if image is None:
                        print(f"Error: Unable to read image {image_path}")
                        continue  # skip unreadable files; their idx stays unused

                    image = cv2.resize(image, (64, 64))
                    _, buffer = cv2.imencode('.jpg', image)
                    txn.put(f"image-{idx:05d}".encode(), buffer.tobytes())
                    if label != -1:
                        txn.put(f"label-{idx:05d}".encode(), np.array(label, dtype=np.int64).tobytes())
        finally:
            env.close()

        print(f"{dataset_type.capitalize()} dataset converted to LMDB successfully.")




