import os
from io import BytesIO
from pathlib import Path
import time
import lmdb
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10, LSUNClass
import torch
import torch.nn.functional as F
import pandas as pd

import torchvision.transforms.functional as Ftrans

import os
import random
from torch.utils.data import Dataset
from PIL import Image
from skimage.transform import PiecewiseAffineTransform, warp

import numpy as np
import cv2
from insightface.app import FaceAnalysis
import pickle


def tps_warp_image(src_img, src_points, dst_points, dst_shape):
    """Deform ``src_img`` so that ``src_points`` move onto ``dst_points``.

    Despite the ``tps`` prefix, the deformation is a piecewise affine
    transform (``skimage.transform.PiecewiseAffineTransform``).

    Args:
        src_img: image array to deform.
        src_points: N x 2 array of control points in ``src_img``.
        dst_points: N x 2 array of the matching points in the output.
        dst_shape: output image size as ``(height, width)``.

    Returns:
        The warped image as a ``uint8`` array.
    """
    # ``warp`` uses an inverse mapping (output coords -> input coords), so
    # the transform is fitted from dst to src, not the other way round.
    inverse_map = PiecewiseAffineTransform()
    inverse_map.estimate(dst_points, src_points)

    deformed = warp(src_img, inverse_map, output_shape=dst_shape, mode='edge')
    # Scale back to the 0-255 range before casting to uint8.
    return (deformed * 255).astype(np.uint8)


class _FacePairBase(Dataset):
    """Shared implementation of the face-pair datasets defined below.

    The data root is expected to contain one sub-folder per identity, each
    holding that identity's face images. Landmarks and identity embeddings
    are cached as pickle files under ``cache_dir``, mirroring the layout of
    the data root.
    """

    # File extensions (lower-case) that count as face images.
    IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp')
    # Size every image is resized to before any processing.
    IMAGE_SIZE = (256, 256)
    # Side length of the intermediate low-resolution image; set per subclass.
    LR_SIZE = 32
    # Upper bound on sampling attempts in ``__getitem__``.
    MAX_RETRIES = 100

    def __init__(self, path, mode, eval_mode='normal', cache_dir=None):
        """
        Args:
            path: data root, laid out as ``path/<id>/<face images>``.
            mode: stored as ``self.mode``; its use is subclass specific.
            eval_mode: ``'normal'`` or anything else (treated as cross-ID).
            cache_dir: cache location; defaults to ``<parent of path>/cache``.

        Raises:
            ValueError: if ``path`` contains no sub-folders.
        """
        self.root_dir = path

        if cache_dir is None:
            cache_dir = os.path.join(os.path.dirname(path), 'cache')
        self.cache_dir = cache_dir
        # exist_ok avoids the check-then-create race between worker processes.
        os.makedirs(self.cache_dir, exist_ok=True)

        self.mode = mode
        self.eval_mode = eval_mode
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.face_app = FaceAnalysis(name='antelopev2/antelopev2', root='./checkpoints', providers=['CPUExecutionProvider'])
        self.face_app.prepare(ctx_id=0, det_size=(256, 256))

        self.id_dirs = [
            os.path.join(path, d) for d in os.listdir(path)
            if os.path.isdir(os.path.join(path, d))
        ]
        if not self.id_dirs:
            raise ValueError("没有找到任何 ID 文件夹")

        for id_dir in self.id_dirs:
            os.makedirs(self._get_cache_path(id_dir), exist_ok=True)

    # ------------------------------------------------------------------
    # cache helpers
    # ------------------------------------------------------------------
    def _get_cache_path(self, original_path):
        """Map a path under the data root to the same relative path under the cache."""
        rel_path = os.path.relpath(original_path, self.root_dir)
        return os.path.join(self.cache_dir, rel_path)

    def _cache_file(self, image_path, suffix):
        """Return the cache file for ``image_path`` with its extension replaced by ``suffix``."""
        rel_path = os.path.relpath(image_path, self.root_dir)
        return os.path.join(self.cache_dir, os.path.splitext(rel_path)[0] + suffix)

    def _get_landmark_cache_file(self, image_path):
        """Return the landmark cache file path for ``image_path``."""
        return self._cache_file(image_path, '_landmarks.pkl')

    def _get_id_emb_cache_file(self, image_path):
        """Return the identity-embedding cache file path for ``image_path``."""
        return self._cache_file(image_path, '_id_emb.pkl')

    @staticmethod
    def _load_pickle(cache_file, label):
        """Load ``cache_file``; return None if it is missing or unreadable."""
        if not os.path.exists(cache_file):
            return None
        try:
            with open(cache_file, 'rb') as f:
                # NOTE: unpickling can execute arbitrary code. Only point
                # cache_dir at locations whose contents you trust.
                return pickle.load(f)
        except Exception as e:
            print(f"加载{label}缓存失败 {cache_file}: {e}")
            return None

    @staticmethod
    def _save_pickle(cache_file, value, label):
        """Write ``value`` to ``cache_file``; failures are reported, not raised."""
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        # Write to a per-process temp file and rename, so a concurrent reader
        # never sees a half-written cache entry.
        tmp_file = f"{cache_file}.{os.getpid()}.tmp"
        try:
            with open(tmp_file, 'wb') as f:
                pickle.dump(value, f)
            os.replace(tmp_file, cache_file)
        except Exception as e:
            print(f"保存{label}缓存失败 {cache_file}: {e}")

    def _load_cached_landmarks(self, image_path):
        """Load cached landmarks for ``image_path`` (None when unavailable)."""
        return self._load_pickle(self._get_landmark_cache_file(image_path), 'landmarks')

    def _load_cached_id_emb(self, image_path):
        """Load the cached identity embedding for ``image_path`` (None when unavailable)."""
        return self._load_pickle(self._get_id_emb_cache_file(image_path), 'id_emb')

    def _save_landmarks_cache(self, image_path, landmarks):
        """Store ``landmarks`` in the cache entry of ``image_path``."""
        self._save_pickle(self._get_landmark_cache_file(image_path), landmarks, 'landmarks')

    def _save_id_emb_cache(self, image_path, id_emb):
        """Store ``id_emb`` in the cache entry of ``image_path``."""
        self._save_pickle(self._get_id_emb_cache_file(image_path), id_emb, 'id_emb')

    def clear_cache(self):
        """Delete every cache file and recreate an empty cache directory."""
        import shutil
        if os.path.exists(self.cache_dir):
            shutil.rmtree(self.cache_dir)
            os.makedirs(self.cache_dir)
            print("缓存已清空")

    def get_cache_stats(self):
        """Print and return ``(cached_images, total_images, cache_ratio)``.

        An image counts as cached when its landmark cache file exists.
        """
        total_images = 0
        cached_images = 0
        for id_dir in self.id_dirs:
            face_files = self._list_face_files(id_dir)
            total_images += len(face_files)
            cached_images += sum(
                1 for image_path in face_files
                if os.path.exists(self._get_landmark_cache_file(image_path))
            )

        cache_ratio = cached_images / total_images if total_images > 0 else 0
        print(f"缓存统计: {cached_images}/{total_images} ({cache_ratio:.1%}) 已缓存")
        return cached_images, total_images, cache_ratio

    # ------------------------------------------------------------------
    # face analysis
    # ------------------------------------------------------------------
    def extract_id_emb(self, img: "Image.Image", image_path=None) -> torch.Tensor:
        """Return the L2-normalised embedding (shape ``[1, D]``) of the largest face.

        Args:
            img: PIL image to analyse.
            image_path: when given, the result is read from / written to the
                cache entry of that path.

        Raises:
            ValueError: if no face is detected.
        """
        if image_path is not None:
            cached_id_emb = self._load_cached_id_emb(image_path)
            if cached_id_emb is not None:
                return cached_id_emb

        np_img = np.array(img)[:, :, ::-1]  # PIL -> BGR
        faces = self.face_app.get(np_img)
        if not faces:
            raise ValueError(f"未检测到人脸: {image_path if image_path else 'unknown'}")

        # Largest bounding-box area wins.
        largest_face = sorted(faces, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
        id_emb = torch.tensor(largest_face['embedding'], dtype=torch.float32)[None]
        id_emb = id_emb / torch.norm(id_emb, dim=1, keepdim=True)

        if image_path is not None:
            self._save_id_emb_cache(image_path, id_emb)
        return id_emb

    def get_landmarks(self, img, image_path=None):
        """Return the 106-point 2D landmarks (float32) of the first detected face.

        Args:
            img: image array handed to the face detector.
            image_path: when given, the result is read from / written to the
                cache entry of that path.

        Raises:
            ValueError: if no face is detected.
        """
        if image_path is not None:
            cached_landmarks = self._load_cached_landmarks(image_path)
            if cached_landmarks is not None:
                return cached_landmarks

        # NOTE(review): callers in this file pass RGB arrays here, whereas
        # extract_id_emb reverses the channels first - confirm which channel
        # order the detector expects. Also note this takes faces[0], not the
        # largest face.
        faces = self.face_app.get(img)
        if not faces:
            raise ValueError(f"未检测到人脸: {image_path if image_path else 'unknown'}")

        landmarks = faces[0].landmark_2d_106.astype(np.float32)

        if image_path is not None:
            self._save_landmarks_cache(image_path, landmarks)
        return landmarks

    # ------------------------------------------------------------------
    # sample construction
    # ------------------------------------------------------------------
    def _list_face_files(self, id_dir):
        """Return the full paths of the image files directly inside ``id_dir``."""
        return [
            os.path.join(id_dir, f) for f in os.listdir(id_dir)
            if os.path.isfile(os.path.join(id_dir, f))
            and f.lower().endswith(self.IMAGE_EXTS)
        ]

    def _pick_cross_ref(self):
        """Pick a random image from a random ID folder (None if that folder has no images)."""
        ref_files = self._list_face_files(random.choice(self.id_dirs))
        return random.choice(ref_files) if ref_files else None

    @staticmethod
    def _open_rgb(image_path):
        """Open ``image_path`` as an RGB PIL image and release the file handle."""
        with Image.open(image_path) as im:
            return im.convert("RGB")

    def _make_sample(self, hr_path, ref_path):
        """Build the images and tensors for one (hr, ref) pair.

        Returns:
            ``(hr, ref, lr, hr_tensor, lr_tensor, ref_tensor, warp_tensor)``
            where the first three are resized PIL images, or None when no
            face is found in either image.
        """
        hr = self._open_rgb(hr_path).resize(self.IMAGE_SIZE, Image.BICUBIC)
        ref = self._open_rgb(ref_path).resize(self.IMAGE_SIZE, Image.BICUBIC)
        # Degrade hr by downsampling to LR_SIZE and upsampling back.
        lr_lowers = hr.resize((self.LR_SIZE, self.LR_SIZE), Image.BICUBIC)
        lr = lr_lowers.resize(self.IMAGE_SIZE, Image.BICUBIC)

        hr_np, ref_np, lr_np = np.array(hr), np.array(ref), np.array(lr)

        try:
            pts_ref = self.get_landmarks(ref_np, ref_path)
            pts_hr = self.get_landmarks(hr_np, hr_path)
        except ValueError as e:
            print(f"跳过图片对，原因: {e}")
            return None

        # Warp the reference face onto the hr landmarks, then paste it over
        # the low-resolution image inside the convex hull of those landmarks.
        warped = tps_warp_image(ref_np, pts_ref, pts_hr, (lr.size[1], lr.size[0]))
        hull = cv2.convexHull(pts_hr.astype(np.int32))
        mask = np.zeros_like(hr_np, dtype=np.float32)
        cv2.fillConvexPoly(mask, hull, (1.0, 1.0, 1.0))

        outside_face = lr_np * (1 - mask)
        composite = warped * mask + outside_face
        lr_mask = np.clip(outside_face, 0, 255).astype(np.uint8)
        warp_img_c = np.clip(composite, 0, 255).astype(np.uint8)

        return (
            hr, ref, lr,
            self.transform(hr),
            self.transform(lr_mask),
            self.transform(ref),
            self.transform(warp_img_c),
        )


class FacePairDatasetTest(_FacePairBase):
    """Evaluation dataset: one sample per index, drawn from ID folder ``idx % n_ids``.

    In ``'normal'`` eval mode the reference is a single fixed image
    (``ref_path``); otherwise it is a random image from a random ID folder.
    Each item is ``(hr, lr_masked, ref, warp, id_emb, hr_path)``.
    """

    LR_SIZE = 16
    DEFAULT_REF_PATH = '/data/yangjiarui/diffae/datasets/sr_test/00001/10.png'

    def __init__(self, path, mode='test', eval_mode='normal', cache_dir=None, ref_path=None):
        """
        Args:
            path: data root, laid out as ``path/<id>/<face images>``.
            mode: stored as ``self.mode``.
            eval_mode: ``'normal'`` uses the fixed reference image, any other
                value picks a random reference.
            cache_dir: cache location; defaults to ``<parent of path>/cache``.
            ref_path: fixed reference image for ``'normal'`` mode; defaults to
                ``DEFAULT_REF_PATH``.
        """
        super().__init__(path, mode=mode, eval_mode=eval_mode, cache_dir=cache_dir)
        self.ref_path = self.DEFAULT_REF_PATH if ref_path is None else ref_path

    def __len__(self):
        # Counts every directory entry of every ID folder (not only images).
        return sum(len(os.listdir(id_dir)) for id_dir in self.id_dirs)

    def __getitem__(self, idx):
        """Return one sample for ``idx``.

        Raises:
            ValueError: if the selected ID folder has fewer than two images
                (previously this silently returned None).
            RuntimeError: if no usable pair is found in ``MAX_RETRIES`` tries
                (previously unbounded recursion).
        """
        id_dir = self.id_dirs[idx % len(self.id_dirs)]
        face_files = self._list_face_files(id_dir)
        if len(face_files) < 2:
            raise ValueError(
                f"ID folder needs at least 2 images, found {len(face_files)}: {id_dir}")

        for _ in range(self.MAX_RETRIES):
            hr_path = random.sample(face_files, 2)[0]
            print(hr_path)

            if self.eval_mode == 'normal':
                ref_path = self.ref_path
                print(ref_path)
            else:
                ref_path = self._pick_cross_ref()
                if ref_path is None:
                    continue

            sample = self._make_sample(hr_path, ref_path)
            if sample is None:
                continue
            _, ref, _, hr_tensor, lr_tensor, ref_tensor, warp_img = sample

            try:
                # The embedding always comes from the reference and bypasses
                # the cache, whatever the eval mode.
                id_emb = self.extract_id_emb(ref, None)
            except Exception as e:
                print(f"跳过样本，原因: {e}")
                continue

            return hr_tensor, lr_tensor, ref_tensor, warp_img, id_emb, hr_path

        raise RuntimeError(
            f"no usable image pair found for index {idx} after {self.MAX_RETRIES} attempts")


class FacePairDataset(_FacePairBase):
    """Training dataset: every item is a freshly sampled random pair.

    In ``'normal'`` eval mode hr and ref are two images of the same ID and
    the embedding comes from hr; otherwise ref is a random image from a
    random ID and the embedding comes from ref. Each item is
    ``(hr, lr_masked, ref, warp, id_emb, lq)``.
    """

    LR_SIZE = 32
    # Random draws may hit many folders with < 2 images, so allow far more
    # attempts than the evaluation dataset before giving up.
    MAX_RETRIES = 10000

    def __init__(self, path, mode='train', eval_mode='normal', cache_dir=None):
        """
        Args:
            path: data root, laid out as ``path/<id>/<face images>``.
            mode: ``'train'`` makes ``len()`` report 100000, anything else 1.
            eval_mode: ``'normal'`` for same-ID references, any other value
                for random references.
            cache_dir: cache location; defaults to ``<parent of path>/cache``.
        """
        super().__init__(path, mode=mode, eval_mode=eval_mode, cache_dir=cache_dir)

    def __len__(self):
        # Samples are drawn at random, so the length is a nominal epoch size.
        return 100000 if self.mode == 'train' else 1

    def __getitem__(self, idx):
        """Return one randomly sampled item (``idx`` is ignored).

        Raises:
            RuntimeError: if no usable pair is found in ``MAX_RETRIES`` tries.
        """
        same_id = self.eval_mode == 'normal'

        for _ in range(self.MAX_RETRIES):
            face_files = self._list_face_files(random.choice(self.id_dirs))
            if len(face_files) < 2:
                continue
            hr_path, partner_path = random.sample(face_files, 2)

            if same_id:
                ref_path = partner_path
            else:
                ref_path = self._pick_cross_ref()
                if ref_path is None:
                    continue

            sample = self._make_sample(hr_path, ref_path)
            if sample is None:
                continue
            hr, ref, lr, hr_tensor, lr_tensor, ref_tensor, warp_img = sample
            lq = self.transform(lr)

            try:
                # Any eval_mode other than 'normal' is handled as cross-ID,
                # matching how the reference image was chosen above.
                if same_id:
                    id_emb = self.extract_id_emb(hr, hr_path)
                else:
                    id_emb = self.extract_id_emb(ref, ref_path)
            except Exception as e:
                print(f"跳过样本，原因: {e}")
                continue

            return hr_tensor, lr_tensor, ref_tensor, warp_img, id_emb, lq

        raise RuntimeError(
            f"no usable image pair found after {self.MAX_RETRIES} attempts")

class ImageDataset(Dataset):
    """Dataset over the image files found in a folder.

    Items are dicts ``{'img': <transformed image>, 'index': <int>}``.
    """
    def __init__(
        self,
        folder,
        image_size,
        exts=['jpg'],
        do_augment: bool = True,
        do_transform: bool = True,
        do_normalize: bool = True,
        sort_names=False,
        has_subdir: bool = True,
    ):
        """
        Args:
            folder: directory to scan.
            image_size: size passed to ``Resize`` and ``CenterCrop``.
            exts: file extensions (without the dot) to collect.
            do_augment: append a random horizontal flip.
            do_transform: append ``ToTensor``.
            do_normalize: append normalisation with mean/std 0.5 per channel.
            sort_names: sort the collected paths.
            has_subdir: search recursively instead of only the top level.
        """
        super().__init__()
        self.folder = folder
        self.image_size = image_size

        # Paths are stored relative to the folder (shorter, cheaper to sort).
        prefix = '**/' if has_subdir else ''
        collected = []
        for ext in exts:
            for found in Path(f'{folder}').glob(f'{prefix}*.{ext}'):
                collected.append(found.relative_to(folder))
        self.paths = sorted(collected) if sort_names else collected

        steps = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            steps.append(transforms.RandomHorizontalFlip())
        if do_transform:
            steps.append(transforms.ToTensor())
        if do_normalize:
            steps.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        full_path = os.path.join(self.folder, self.paths[index])
        # Convert so that e.g. RGBA inputs also end up with three channels.
        image = Image.open(full_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return {'img': image, 'index': index}


class SubsetDataset(Dataset):
    """View of the first ``size`` items of another dataset."""

    def __init__(self, dataset, size):
        """
        Args:
            dataset: the dataset to wrap; must support ``len()`` and indexing.
            size: number of leading items to expose.

        Raises:
            ValueError: if ``dataset`` holds fewer than ``size`` items.
        """
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        if len(dataset) < size:
            raise ValueError(
                f'dataset has {len(dataset)} items, cannot take a subset of {size}')
        self.dataset = dataset
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset.

        Raises:
            IndexError: if ``index`` is outside ``[0, size)``. IndexError (not
                AssertionError) is what lets plain iteration stop cleanly, and
                rejecting negatives keeps items outside the subset hidden.
        """
        if not 0 <= index < self.size:
            raise IndexError(
                f'index {index} out of range for subset of size {self.size}')
        return self.dataset[index]


class BaseLMDB(Dataset):
    """Read-only access to images stored in an LMDB database.

    Images are looked up under the key ``'{original_resolution}-{index}'``
    with the index zero-padded to ``zfill`` digits; the number of images is
    read from the ``'length'`` key.
    """

    def __init__(self, path, original_resolution, zfill: int = 5):
        """
        Args:
            path: location of the LMDB database.
            original_resolution: resolution prefix used in the image keys.
            zfill: number of digits the index is zero-padded to.

        Raises:
            IOError: if the database cannot be opened or has no ``length`` entry.
        """
        self.original_resolution = original_resolution
        self.zfill = zfill
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))
        if raw_length is None:
            raise IOError('lmdb dataset has no "length" entry', path)
        self.length = int(raw_length.decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return the image stored for ``index`` as a PIL image.

        Raises:
            IndexError: if the database holds no entry for ``index``
                (previously this surfaced as an opaque image-decoding error).
        """
        key = f'{self.original_resolution}-{str(index).zfill(self.zfill)}'.encode(
            'utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)

        if img_bytes is None:
            raise IndexError(f'no lmdb entry for index {index} (key {key!r})')

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        return img


def make_transform(
    image_size,
    flip_prob=0.5,
    crop_d2c=False,
):
    """Build the image preprocessing pipeline.

    Args:
        image_size: size passed to ``Resize`` (and ``CenterCrop``).
        flip_prob: probability of the random horizontal flip.
        crop_d2c: use the D2C crop followed by a resize instead of
            resize + centre crop.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and normalisation
        with mean/std 0.5 per channel.
    """
    if crop_d2c:
        geometry = [d2c_crop(), transforms.Resize(image_size)]
    else:
        geometry = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]

    return transforms.Compose(geometry + [
        transforms.RandomHorizontalFlip(p=flip_prob),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


class FFHQlmdb(Dataset):
    """FFHQ images read from an LMDB database, with optional train/test split.

    Items are dicts ``{'img': <transformed image>, 'index': <int>}`` where
    ``index`` is the position in the underlying database (split offset
    included).
    """

    def __init__(self,
                 path=os.path.expanduser('datasets/ffhq256.lmdb'),
                 image_size=256,
                 original_resolution=256,
                 split=None,
                 as_tensor: bool = True,
                 do_augment: bool = True,
                 do_normalize: bool = True,
                 **kwargs):
        """
        Args:
            path: location of the LMDB database.
            image_size: size passed to ``Resize``.
            original_resolution: resolution prefix of the stored image keys.
            split: None (everything), ``'train'`` (all but the first 10000
                entries) or ``'test'`` (the first 10000 entries).
            as_tensor: append ``ToTensor``.
            do_augment: append a random horizontal flip.
            do_normalize: append normalisation with mean/std 0.5 per channel.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: for any other ``split`` value.
        """
        self.original_resolution = original_resolution
        self.data = BaseLMDB(path, original_resolution, zfill=5)
        self.length = len(self.data)

        if split is None:
            self.offset = 0
        elif split == 'train':
            # everything after the first 10k
            self.length = self.length - 10000
            self.offset = 10000
        elif split == 'test':
            # first 10k
            self.length = 10000
            self.offset = 0
        else:
            raise NotImplementedError()

        transform = [
            transforms.Resize(image_size),
        ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if as_tensor:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return the item at ``index`` within the selected split.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``. An explicit
                IndexError (instead of an assert) lets plain iteration
                terminate, survives ``python -O``, and stops negative indices
                from reaching into the other split.
        """
        if not 0 <= index < self.length:
            raise IndexError(
                f'index {index} out of range for dataset of length {self.length}')
        index = index + self.offset
        img = self.data[index]
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index}


class Crop:
    """Callable transform that crops a fixed window out of an image.

    The four bounds are forwarded to ``Ftrans.crop`` as
    ``(x1, y1, x2 - x1, y2 - y1)``.
    """

    def __init__(self, x1, x2, y1, y2):
        self.x1, self.x2, self.y1, self.y2 = x1, x2, y1, y2

    def __call__(self, img):
        # NOTE(review): torchvision documents crop(img, top, left, height,
        # width), so x* would be the vertical bounds and y* the horizontal
        # ones - confirm against callers.
        extent_x = self.x2 - self.x1
        extent_y = self.y2 - self.y1
        return Ftrans.crop(img, self.x1, self.y1, extent_x, extent_y)

    def __repr__(self):
        bounds = f"(x1={self.x1}, x2={self.x2}, y1={self.y1}, y2={self.y2})"
        return self.__class__.__name__ + bounds


def d2c_crop():
    """Return the fixed 128 x 128 ``Crop`` used for CelebA in the D2C paper.

    The window spans 64 pixels on each side of the centre (cx=89, cy=121).
    """
    cx, cy = 89, 121
    half = 64
    # Crop takes (x1, x2, y1, y2); the cy-based bounds go in the x slots.
    return Crop(cy - half, cy + half, cx - half, cx + half)


class CelebAlmdb(Dataset):
    """CelebA images read from an LMDB database; also supports the D2C crop.

    Items are dicts ``{'img': <transformed image>, 'index': <int>}``.
    """
    def __init__(self,
                 path,
                 image_size,
                 original_resolution=128,
                 split=None,
                 as_tensor: bool = True,
                 do_augment: bool = True,
                 do_normalize: bool = True,
                 crop_d2c: bool = False,
                 **kwargs):
        """
        Args:
            path: location of the LMDB database.
            image_size: size passed to ``Resize`` (and ``CenterCrop``).
            original_resolution: resolution prefix of the stored image keys.
            split: must be None; no splits are implemented.
            as_tensor: append ``ToTensor``.
            do_augment: append a random horizontal flip.
            do_normalize: append normalisation with mean/std 0.5 per channel.
            crop_d2c: use the D2C crop followed by a resize instead of
                resize + centre crop.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: if ``split`` is not None.
        """
        self.original_resolution = original_resolution
        self.data = BaseLMDB(path, original_resolution, zfill=7)
        self.length = len(self.data)
        self.crop_d2c = crop_d2c

        if split is None:
            self.offset = 0
        else:
            raise NotImplementedError()

        if crop_d2c:
            transform = [
                d2c_crop(),
                transforms.Resize(image_size),
            ]
        else:
            transform = [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ]

        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if as_tensor:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return the item at ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``. An explicit
                IndexError (instead of an assert) lets plain iteration
                terminate and survives ``python -O``.
        """
        if not 0 <= index < self.length:
            raise IndexError(
                f'index {index} out of range for dataset of length {self.length}')
        index = index + self.offset
        img = self.data[index]
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index}


class Horse_lmdb(Dataset):
    """Horse images read from an LMDB database.

    Items are dicts ``{'img': <transformed image>, 'index': <int>}``.
    """

    def __init__(self,
                 path=os.path.expanduser('datasets/horse256.lmdb'),
                 image_size=128,
                 original_resolution=256,
                 do_augment: bool = True,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 **kwargs):
        """
        Args:
            path: location of the LMDB database.
            image_size: size passed to ``Resize`` and ``CenterCrop``.
            original_resolution: resolution prefix of the stored image keys.
            do_augment: append a random horizontal flip.
            do_transform: append ``ToTensor``.
            do_normalize: append normalisation with mean/std 0.5 per channel.
            **kwargs: accepted and ignored.
        """
        self.original_resolution = original_resolution
        self.data = BaseLMDB(path, original_resolution, zfill=7)
        self.length = len(self.data)

        steps = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            steps.append(transforms.RandomHorizontalFlip())
        if do_transform:
            steps.append(transforms.ToTensor())
        if do_normalize:
            steps.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        image = self.data[index]
        if self.transform is None:
            return {'img': image, 'index': index}
        return {'img': self.transform(image), 'index': index}


class Bedroom_lmdb(Dataset):
    """Bedroom images read from an LMDB database.

    Items are dicts ``{'img': <transformed image>, 'index': <int>}``.
    """

    def __init__(self,
                 path=os.path.expanduser('datasets/bedroom256.lmdb'),
                 image_size=128,
                 original_resolution=256,
                 do_augment: bool = True,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 **kwargs):
        """
        Args:
            path: location of the LMDB database (printed on construction).
            image_size: size passed to ``Resize`` and ``CenterCrop``.
            original_resolution: resolution prefix of the stored image keys.
            do_augment: append a random horizontal flip.
            do_transform: append ``ToTensor``.
            do_normalize: append normalisation with mean/std 0.5 per channel.
            **kwargs: accepted and ignored.
        """
        self.original_resolution = original_resolution
        print(path)
        self.data = BaseLMDB(path, original_resolution, zfill=7)
        self.length = len(self.data)

        steps = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            steps.append(transforms.RandomHorizontalFlip())
        if do_transform:
            steps.append(transforms.ToTensor())
        if do_normalize:
            steps.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Unlike the sibling datasets, the transform is applied unconditionally.
        return {'img': self.transform(self.data[index]), 'index': index}


class CelebAttrDataset(Dataset):
    """CelebA images paired with their 40 binary attribute labels.

    Items are dicts ``{'img', 'index', 'labels'}`` where ``labels`` is a
    tensor holding the attribute values of the image in ``id_to_cls`` order.
    """

    # Attribute names, indexed by label position.
    id_to_cls = [
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
        'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
        'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
        'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
        'Wearing_Necklace', 'Wearing_Necktie', 'Young'
    ]
    # Reverse lookup: attribute name -> label position.
    cls_to_id = {v: k for k, v in enumerate(id_to_cls)}

    def __init__(self,
                 folder,
                 image_size=64,
                 attr_path=os.path.expanduser(
                     'datasets/celeba_anno/list_attr_celeba.txt'),
                 ext='png',
                 only_cls_name: str = None,
                 only_cls_value: int = None,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 d2c: bool = False):
        """
        Args:
            folder: directory containing the images.
            image_size: size passed to ``Resize`` (and ``CenterCrop``).
            attr_path: whitespace-separated attribute file; its first line is
                discarded before parsing.
            ext: extension (without the dot) of the image files on disk.
            only_cls_name: if given, keep only rows where this attribute
                equals ``only_cls_value``.
            only_cls_value: value matched by ``only_cls_name``.
            do_augment: append a random horizontal flip.
            do_transform: append ``ToTensor``.
            do_normalize: append normalisation with mean/std 0.5 per channel.
            d2c: use the D2C crop followed by a resize instead of
                resize + centre crop.
        """
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.ext = ext

        # The attribute table is indexed by '<name>.jpg', so the files found
        # on disk are mapped to that naming before matching.
        paths = [
            str(p.relative_to(folder)).split('.')[0] + '.jpg'
            for p in Path(f'{folder}').glob(f'**/*.{ext}')
        ]

        if d2c:
            transform = [
                d2c_crop(),
                transforms.Resize(image_size),
            ]
        else:
            transform = [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

        with open(attr_path) as f:
            # discard the top line
            f.readline()
            # sep=r'\s+' is the documented replacement for the deprecated
            # delim_whitespace=True.
            df = pd.read_csv(f, sep=r'\s+')
        self.df = df[df.index.isin(paths)]

        if only_cls_name is not None:
            self.df = self.df[self.df[only_cls_name] == only_cls_value]

    def pos_count(self, cls_name):
        """Number of rows where attribute ``cls_name`` equals 1."""
        return (self.df[cls_name] == 1).sum()

    def neg_count(self, cls_name):
        """Number of rows where attribute ``cls_name`` equals -1."""
        return (self.df[cls_name] == -1).sum()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        # Swap the extension of the table's file name for the on-disk one.
        name = row.name.split('.')[0]
        name = f'{name}.{self.ext}'

        path = os.path.join(self.folder, name)
        img = Image.open(path)

        labels = [0] * len(self.id_to_cls)
        for k, v in row.items():
            labels[self.cls_to_id[k]] = int(v)

        if self.transform is not None:
            img = self.transform(img)

        return {'img': img, 'index': index, 'labels': torch.tensor(labels)}


class CelebD2CAttrDataset(CelebAttrDataset):
    """
    CelebA attribute dataset as used in the D2C paper, which applies its own
    specific crop to the original CelebA images.

    Every argument is handed to ``CelebAttrDataset`` unchanged; this subclass
    only supplies its own defaults (notably ``d2c=True``).
    """
    def __init__(
            self,
            folder,
            image_size=64,
            attr_path=os.path.expanduser(
                'datasets/celeba_anno/list_attr_celeba.txt'),
            ext='jpg',
            only_cls_name: str = None,
            only_cls_value: int = None,
            do_augment: bool = False,
            do_transform: bool = True,
            do_normalize: bool = True,
            d2c: bool = True):
        # the first three arguments stay positional, exactly as the parent
        # receives them; the rest are passed by keyword
        super().__init__(
            folder, image_size, attr_path,
            ext=ext,
            only_cls_name=only_cls_name, only_cls_value=only_cls_value,
            do_augment=do_augment, do_transform=do_transform,
            do_normalize=do_normalize,
            d2c=d2c,
        )


class CelebAttrFewshotDataset(Dataset):
    """
    Few-shot CelebA subset labelled for a single attribute.

    The sample list is read from a CSV under ``data/celeba_fewshots/`` whose
    name is built from ``K``, ``cls_name`` and ``seed`` (with an ``allneg``
    variant); images are opened from ``img_folder``.
    """
    def __init__(
        self,
        cls_name,
        K,
        img_folder,
        img_size=64,
        ext='png',
        seed=0,
        only_cls_name: str = None,
        only_cls_value: int = None,
        all_neg: bool = False,
        do_augment: bool = False,
        do_transform: bool = True,
        do_normalize: bool = True,
        d2c: bool = False,
    ) -> None:
        """
        Args:
            cls_name: attribute column used as the label.
            K: few-shot size; only used to pick the CSV file.
            img_folder: directory the image files are opened from.
            img_size: size given to Resize (and CenterCrop when not d2c).
            ext: extension appended to each image stem.
            seed: seed suffix of the CSV file name.
            only_cls_name / only_cls_value: when a name is given, keep only
                rows whose column ``only_cls_name`` equals ``only_cls_value``.
            all_neg: read the ``allneg`` CSV variant instead.
            do_augment: append a random horizontal flip.
            do_transform: append ``ToTensor``.
            do_normalize: append normalization with mean/std 0.5 per channel.
            d2c: use ``d2c_crop`` + Resize instead of Resize + CenterCrop.
        """
        self.cls_name = cls_name
        self.K = K
        self.img_folder = img_folder
        self.ext = ext

        csv_path = (f'data/celeba_fewshots/K{K}_allneg_{cls_name}_{seed}.csv'
                    if all_neg else
                    f'data/celeba_fewshots/K{K}_{cls_name}_{seed}.csv')
        table = pd.read_csv(csv_path, index_col=0)
        if only_cls_name is not None:
            keep = table[only_cls_name] == only_cls_value
            table = table[keep]
        self.df = table

        if d2c:
            steps = [d2c_crop(), transforms.Resize(img_size)]
        else:
            steps = [transforms.Resize(img_size),
                     transforms.CenterCrop(img_size)]
        if do_augment:
            steps.append(transforms.RandomHorizontalFlip())
        if do_transform:
            steps.append(transforms.ToTensor())
        if do_normalize:
            half = (0.5, 0.5, 0.5)
            steps.append(transforms.Normalize(half, half))
        self.transform = transforms.Compose(steps)

    def pos_count(self, cls_name):
        """Count the rows whose ``cls_name`` column equals 1."""
        is_positive = self.df[cls_name] == 1
        return is_positive.sum()

    def neg_count(self, cls_name):
        """Count the rows whose ``cls_name`` column equals -1."""
        is_negative = self.df[cls_name] == -1
        return is_negative.sum()

    def __len__(self):
        """Number of rows in the few-shot table."""
        return len(self.df.index)

    def __getitem__(self, index):
        """
        Return ``{'img', 'index', 'labels'}`` for the row at ``index``.

        The image file is the part of the row's index label before the first
        '.', plus ``self.ext``, under ``self.img_folder``. ``'labels'`` holds
        only the ``self.cls_name`` value of that row.
        """
        record = self.df.iloc[index]
        stem = record.name.split('.')[0]
        img = Image.open(os.path.join(self.img_folder, f'{stem}.{self.ext}'))

        # scalar label -> tensor of shape (1,)
        scalar = torch.tensor(int(record[self.cls_name]))
        label = scalar.unsqueeze(-1)

        if self.transform is not None:
            img = self.transform(img)

        return {'img': img, 'index': index, 'labels': label}


class CelebD2CAttrFewshotDataset(CelebAttrFewshotDataset):
    """
    ``CelebAttrFewshotDataset`` with D2C-oriented defaults (``ext='jpg'``,
    ``d2c=True``).

    All arguments except ``is_negative`` are forwarded to the parent
    constructor; ``is_negative`` is only recorded on the instance.
    """
    def __init__(
            self,
            cls_name,
            K,
            img_folder,
            img_size=64,
            ext='jpg',
            seed=0,
            only_cls_name: str = None,
            only_cls_value: int = None,
            all_neg: bool = False,
            do_augment: bool = False,
            do_transform: bool = True,
            do_normalize: bool = True,
            is_negative=False,
            d2c: bool = True) -> None:
        super().__init__(
            cls_name, K, img_folder, img_size,
            ext=ext, seed=seed,
            only_cls_name=only_cls_name, only_cls_value=only_cls_value,
            all_neg=all_neg,
            do_augment=do_augment, do_transform=do_transform,
            do_normalize=do_normalize,
            d2c=d2c,
        )
        # stored after the parent has finished setting up; not used by it
        self.is_negative = is_negative


class CelebHQAttrDataset(Dataset):
    """
    CelebA-HQ images read from an LMDB, paired with their attribute labels.

    Images are fetched from a ``BaseLMDB`` keyed by the part of each row's
    index label before the first '.'; labels come from the
    whitespace-separated annotation table at ``attr_path``.
    """
    # attribute names in label-vector order, and the reverse lookup
    id_to_cls = [
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
        'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
        'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
        'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
        'Wearing_Necklace', 'Wearing_Necktie', 'Young'
    ]
    cls_to_id = {v: k for k, v in enumerate(id_to_cls)}

    def __init__(self,
                 path=os.path.expanduser('datasets/celebahq256.lmdb'),
                 image_size=None,
                 attr_path=os.path.expanduser(
                     'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
                 original_resolution=256,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True):
        """
        Args:
            path: LMDB location, passed to ``BaseLMDB``.
            image_size: size for Resize + CenterCrop. ``None`` skips both
                steps, so images keep the size they are stored at.
            attr_path: annotation file; its first line is discarded and the
                remainder is parsed as a whitespace-separated table.
            original_resolution: passed to ``BaseLMDB`` (with ``zfill=5``).
            do_augment: append a random horizontal flip.
            do_transform: append ``ToTensor``.
            do_normalize: append normalization with mean/std 0.5 per channel.
        """
        super().__init__()
        self.image_size = image_size
        self.data = BaseLMDB(path, original_resolution, zfill=5)

        # Resize/CenterCrop need a concrete size; passing the default None
        # straight through made the default constructor call fail.
        transform = []
        if image_size is not None:
            transform.append(transforms.Resize(image_size))
            transform.append(transforms.CenterCrop(image_size))
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

        with open(attr_path) as f:
            # discard the top line
            f.readline()
            # sep=r'\s+' is the supported spelling of the deprecated
            # delim_whitespace=True and parses identically
            self.df = pd.read_csv(f, sep=r'\s+')

    def pos_count(self, cls_name):
        """Count the rows whose ``cls_name`` column equals 1."""
        return (self.df[cls_name] == 1).sum()

    def neg_count(self, cls_name):
        """Count the rows whose ``cls_name`` column equals -1."""
        return (self.df[cls_name] == -1).sum()

    def __len__(self):
        """Number of rows in the annotation table."""
        return len(self.df)

    def __getitem__(self, index):
        """
        Return ``{'img', 'index', 'labels'}`` for the row at ``index``.

        ``'labels'`` is an integer tensor with one slot per ``id_to_cls``
        entry; slots without a matching column in the row stay at 0.
        """
        row = self.df.iloc[index]
        # key into the LMDB by the file stem, same rule as CelebAttrDataset
        img_idx = row.name.split('.')[0]
        img = self.data[img_idx]

        labels = [0] * len(self.id_to_cls)
        for k, v in row.items():
            labels[self.cls_to_id[k]] = int(v)

        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index, 'labels': torch.tensor(labels)}


class CelebHQAttrFewshotDataset(Dataset):
    """
    Few-shot CelebA-HQ subset labelled for a single attribute.

    Images come from a ``BaseLMDB``; the sample list is read from
    ``data/celebahq_fewshots/K{K}_{cls_name}.csv``.
    """
    def __init__(
            self,
            cls_name,
            K,
            path,
            image_size,
            original_resolution=256,
            do_augment: bool = False,
            do_transform: bool = True,
            do_normalize: bool = True):
        """
        Args:
            cls_name: attribute column used as the label.
            K: few-shot size; selects the CSV file.
            path: LMDB location, passed to ``BaseLMDB``.
            image_size: size given to Resize and CenterCrop.
            original_resolution: passed to ``BaseLMDB`` (with ``zfill=5``).
            do_augment: append a random horizontal flip.
            do_transform: append ``ToTensor``.
            do_normalize: append normalization with mean/std 0.5 per channel.
        """
        super().__init__()
        self.image_size = image_size
        self.cls_name = cls_name
        self.K = K
        self.data = BaseLMDB(path, original_resolution, zfill=5)

        steps = [transforms.Resize(image_size),
                 transforms.CenterCrop(image_size)]
        if do_augment:
            steps.append(transforms.RandomHorizontalFlip())
        if do_transform:
            steps.append(transforms.ToTensor())
        if do_normalize:
            half = (0.5, 0.5, 0.5)
            steps.append(transforms.Normalize(half, half))
        self.transform = transforms.Compose(steps)

        csv_path = f'data/celebahq_fewshots/K{K}_{cls_name}.csv'
        self.df = pd.read_csv(csv_path, index_col=0)

    def pos_count(self, cls_name):
        """Count the rows whose ``cls_name`` column equals 1."""
        is_positive = self.df[cls_name] == 1
        return is_positive.sum()

    def neg_count(self, cls_name):
        """Count the rows whose ``cls_name`` column equals -1."""
        is_negative = self.df[cls_name] == -1
        return is_negative.sum()

    def __len__(self):
        """Number of rows in the few-shot table."""
        return len(self.df.index)

    def __getitem__(self, index):
        """
        Return ``{'img', 'index', 'labels'}`` for the row at ``index``.

        The row's index label must split into exactly two parts on '.'; the
        first part is the key used to fetch the image from ``self.data``.
        """
        record = self.df.iloc[index]
        img_idx, _ = record.name.split('.')
        img = self.data[img_idx]

        # scalar label -> tensor of shape (1,)
        scalar = torch.tensor(int(record[self.cls_name]))
        label = scalar.unsqueeze(-1)

        if self.transform is not None:
            img = self.transform(img)

        return {'img': img, 'index': index, 'labels': label}


class Repeat(Dataset):
    """
    Present ``dataset`` as if it had ``new_len`` items by cycling through it.

    Item ``i`` of the wrapper is item ``i % len(dataset)`` of the wrapped
    dataset, with the length taken once at construction time.
    """
    def __init__(self, dataset, new_len) -> None:
        super().__init__()
        self.dataset = dataset
        self.new_len = new_len
        # cached so indexing does not call len() on the wrapped dataset again
        self.original_len = len(dataset)

    def __len__(self):
        return self.new_len

    def __getitem__(self, index):
        wrapped = index % self.original_len
        return self.dataset[wrapped]
