import os
import random
import pandas as pd
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import itertools

from torchvision.transforms import transforms

import ldm.util

def img_preprocess(input_im, carvekit_model):
    """Convert a raw image into the float tensor fed to the model.

    The image is handed to ``ldm.util.load_and_preprocess`` together with
    the CarveKit interface, divided by 255, cast to ``float32``, converted
    with torchvision's ``ToTensor`` and finally mapped through ``x * 2 - 1``.

    Args:
        input_im: Image accepted by ``ldm.util.load_and_preprocess``
            (the datasets in this file pass PIL images).
        carvekit_model: Object forwarded unchanged to
            ``ldm.util.load_and_preprocess``.

    Returns:
        A ``torch.Tensor``. Assuming the helper yields an H x W x C array
        with values in 0..255, the result is C x H x W with values in
        [-1, 1] -- TODO confirm against ``ldm.util``.
    """
    preprocessed = ldm.util.load_and_preprocess(carvekit_model, input_im)
    unit_range = (preprocessed / 255.0).astype(np.float32)
    to_tensor = transforms.ToTensor()
    # Affine map x -> 2x - 1 applied to the tensor produced by ToTensor.
    return to_tensor(unit_range) * 2 - 1



class CommonObjectDataset(Dataset):
    """Images from a directory, each paired with a randomly drawn CSV row.

    Every entry returned by ``os.listdir(image_dir)`` is treated as an image
    file. The CSV is read without a header; each time an item is fetched one
    of its rows is drawn at random (it is *not* tied to the image index).

    NOTE(review): ``DataFrame.sample`` is called without ``random_state``,
    so the draw follows NumPy's global RNG. When using ``DataLoader``
    workers, check that NumPy is seeded per worker -- TODO confirm.
    """

    def __init__(self, image_dir, csv_path, preprocess_fn, carvekit_model, transform=None):
        """Index the image directory and load the angle table.

        Args:
            image_dir: Directory whose entries are the dataset's images.
            csv_path: Path of a header-less CSV; one row is sampled per item.
            preprocess_fn: Callable ``(pil_image, carvekit_model) -> tensor``
                applied to every loaded image.
            carvekit_model: Object forwarded to ``preprocess_fn``.
            transform: Optional callable applied to the preprocessed image.
        """
        self.image_dir = image_dir
        # Sorted so that the index -> file mapping is stable between runs.
        self.image_files = sorted(os.listdir(image_dir))
        self.csv_data = pd.read_csv(csv_path, header=None)  # no header row
        self.transform = transform
        # Not read by any active method of this class; kept as an attribute
        # so existing code that touches it keeps working.
        self.image_index = 0
        self.carvekit_model = carvekit_model
        self.preprocess_fn = preprocess_fn

    def __len__(self):
        """Return the number of entries found in ``image_dir``."""
        return len(self.image_files)

    def _load_sample(self, i):
        """Load image ``i`` and draw one random CSV row.

        Returns:
            ``(image, angle)`` where ``image`` is the preprocessed (and
            optionally transformed) image and ``angle`` is the sampled row
            as a plain list.
        """
        image_path = os.path.join(self.image_dir, self.image_files[i])
        # The context manager releases the file handle deterministically;
        # ``convert`` returns an independent image that outlives it.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")
        image = self.preprocess_fn(image, self.carvekit_model)

        # One random row of the CSV, flattened to a list of values.
        angle = self.csv_data.sample(n=1).values.flatten().tolist()

        if self.transform:
            image = self.transform(image)
        return image, angle

    def __getitem__(self, idx):
        """Fetch one sample, or a stacked batch when ``idx`` is a list.

        Args:
            idx: An integer index, or a list of integer indices.

        Returns:
            For an integer: ``(image, angle)`` with ``angle`` a float32
            tensor holding one CSV row.
            For a list: ``(images, angles)`` where ``images`` is the
            per-item images stacked along a new leading dimension and
            ``angles`` is a float32 tensor with one CSV row per item.
        """
        if isinstance(idx, list):
            samples = [self._load_sample(i) for i in idx]
            images = torch.stack([image for image, _ in samples], dim=0)
            angles = torch.tensor([angle for _, angle in samples], dtype=torch.float32)
            return images, angles

        image, angle = self._load_sample(idx)
        return image, torch.tensor(angle, dtype=torch.float32)

    # def __iter__(self):
    #     self.image_index = 0  # 重置图片索引
    #     return self

    # def __next__(self):
    #     if self.image_index >= len(self.image_files):
    #         self.image_index = 0  # 到达末尾时重置索引，实现循环
    #
    #     # 加载图片
    #     image_path = os.path.join(self.image_dir, self.image_files[self.image_index])
    #     image = Image.open(image_path).convert("RGBA")
    #     self.image_index += 1  # 更新图片索引
    #
    #     # 随机选择CSV中的一行数据
    #     csv_row = self.csv_data.sample(n=1).values.flatten()
    #     angle = csv_row.tolist()
    #
    #     # 应用数据变换（如果有）
    #     if self.transform:
    #         image = self.transform(image)
    #
    #     return image, angle


class ObjectTOForgetAngleDataset(Dataset):
    """Pairs a fixed "forget" image and "override" image with CSV angle rows.

    Both images are loaded once at construction. Item ``i`` combines fresh
    copies of the two images (run through ``preprocess_fn`` and the optional
    ``transform``) with row ``i`` of the header-less angle CSV.

    NOTE(review): ``preprocess_fn`` is re-run on the same two images for
    every item. If it is deterministic the result could be cached -- TODO
    confirm before changing.
    """

    def __init__(self, forget_image_path, override_image_path, angles_csv_path, preprocess_fn, carvekit_model, transform=None, val=False):
        """Load the two images and the angle table.

        Args:
            forget_image_path: Path of the image to forget.
            override_image_path: Path of the image used as the override.
            angles_csv_path: Path of a header-less CSV; one row per item.
            preprocess_fn: Callable ``(pil_image, carvekit_model) -> tensor``.
            carvekit_model: Object forwarded to ``preprocess_fn``.
            transform: Optional callable applied to both preprocessed images
                (not applied in validation mode).
            val: When true, ``__getitem__`` returns ``(forget_image, angle)``
                with ``angle`` as a plain list.
        """
        # ``convert`` returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(forget_image_path) as raw:
            self.forget_image = raw.convert("RGBA")
        with Image.open(override_image_path) as raw:
            self.override_image = raw.convert("RGBA")
        self.angles_csv_data = pd.read_csv(angles_csv_path, header=None)  # no header row
        self.transform = transform
        # Not read by any active method of this class; kept as an attribute
        # so existing code that touches it keeps working.
        self.index = 0
        self.val = val
        self.carvekit_model = carvekit_model
        self.preprocess_fn = preprocess_fn

    def __len__(self):
        """Return the number of rows in the angle CSV."""
        return len(self.angles_csv_data)

    def _read_angle(self, idx):
        """Return CSV row ``idx`` flattened to a plain list of values."""
        return self.angles_csv_data.iloc[idx].values.flatten().tolist()

    def _load_pair(self):
        """Preprocess (and optionally transform) copies of both images."""
        # Copies keep the images cached on ``self`` untouched.
        forget_image = self.preprocess_fn(self.forget_image.copy(), self.carvekit_model)
        override_image = self.preprocess_fn(self.override_image.copy(), self.carvekit_model)
        if self.transform:
            forget_image = self.transform(forget_image)
            override_image = self.transform(override_image)
        return forget_image, override_image

    def __getitem__(self, idx):
        """Fetch one item, or a stacked batch when ``idx`` is a list.

        Args:
            idx: An integer row index, or (outside validation mode) a list
                of integer row indices.

        Returns:
            Validation mode: ``(forget_image, angle)`` where the image is
            preprocessed but not transformed and ``angle`` is a list.
            Integer ``idx``: ``(forget_image, override_image, angle)`` with
            ``angle`` a float32 tensor.
            List ``idx``: the same triple with each element stacked along a
            new leading dimension.
        """
        if self.val:
            angle = self._read_angle(idx)
            # Copy here too, matching the non-validation paths.
            return self.preprocess_fn(self.forget_image.copy(), self.carvekit_model), angle

        if isinstance(idx, list):
            forget_images, override_images, angles = [], [], []
            for i in idx:
                angles.append(self._read_angle(i))
                forget_image, override_image = self._load_pair()
                forget_images.append(forget_image)
                override_images.append(override_image)
            return (
                torch.stack(forget_images, dim=0),
                torch.stack(override_images, dim=0),
                torch.tensor(angles, dtype=torch.float32),
            )

        # Row ``idx`` of the CSV (selected by index, not sampled).
        angle = torch.tensor(self._read_angle(idx), dtype=torch.float32)
        forget_image, override_image = self._load_pair()
        return forget_image, override_image, angle

    # def __iter__(self):
    #     self.index = 0  # 重置索引
    #     return self
    #
    # def __next__(self):
    #     # 如果索引达到数据集末尾，则重置为0，实现循环
    #     if self.index >= len(self.angles_csv_data):
    #         self.index = 0
    #
    #     # 获取当前索引的数据
    #     csv_row = self.angles_csv_data.iloc[self.index].values.flatten()
    #     angle = csv_row.tolist()
    #
    #     # 获取图片副本并应用变换（如果有）
    #     override_image = self.override_image.copy()
    #     forget_image = self.forget_image.copy()
    #     # 应用数据变换（如果有）
    #     if self.transform:
    #         override_image = self.transform(override_image)
    #         forget_image = self.transform(forget_image)
    #
    #     # 更新索引并返回
    #     self.index += 1
    #     return forget_image, override_image, angle


# class ObjectTOForgetAngleDataset(Dataset):
#     def __init__(self, image_path, forget_csv_path, override_csv_path, transform=None, val=False):
#         self.image = Image.open(image_path).convert("RGBA")
#         self.forget_csv_data = pd.read_csv(forget_csv_path, header=None)  # 加载csv文件，无表头
#         self.override_csv_data = pd.read_csv(override_csv_path, header=None)
#         assert len(self.forget_csv_data) == len(self.override_csv_data)
#         self.transform = transform
#         self.index = 0
#         self.val = val
#
#     def __len__(self):
#         return len(self.forget_csv_data)
#
#     def __getitem__(self, idx):
#         if self.val:
#             forget_csv_row = self.forget_csv_data.iloc[idx].values.flatten()
#             forget_angle = forget_csv_row.tolist()
#             return self.image, forget_angle
#         # 随机选择CSV中的一行数据
#         forget_csv_row = self.forget_csv_data.iloc[idx].values.flatten()
#         override_csv_row = self.override_csv_data.iloc[idx].values.flatten()
#         forget_angle = forget_csv_row.tolist()
#         override_angle = override_csv_row.tolist()
#
#         image = self.image.copy()
#         # 应用数据变换（如果有）
#         if self.transform:
#             image = self.transform(self.image)
#
#         return image, forget_angle, override_angle
#
#     def __iter__(self):
#         self.index = 0  # 重置索引
#         return self
#
#     def __next__(self):
#         # 如果索引达到数据集末尾，则重置为0，实现循环
#         if self.index >= len(self.forget_csv_data):
#             self.index = 0
#
#         # 获取当前索引的数据
#         forget_csv_row = self.forget_csv_data.iloc[self.index].values.flatten()
#         override_csv_row = self.override_csv_data.iloc[self.index].values.flatten()
#         forget_angle = forget_csv_row.tolist()
#         override_angle = override_csv_row.tolist()
#
#         # 获取图片副本并应用变换（如果有）
#         image = self.image.copy()
#         if self.transform:
#             image = self.transform(image)
#
#         # 更新索引并返回
#         self.index += 1
#         return image, forget_angle, override_angle


if __name__ == '__main__':
    # Manual smoke test: build both datasets from the local zero123 data,
    # print their sizes, then iterate them through shuffling DataLoaders and
    # print the shape of every batch.
    carvekit_model = ldm.util.create_carvekit_interface()

    common_dataset = CommonObjectDataset(
        image_dir='../zero123_dataset/common_objects/image',
        csv_path='../zero123_dataset/common_objects/cam_angle.csv',
        preprocess_fn=img_preprocess,
        carvekit_model=carvekit_model,
        transform=None,
    )
    print(len(common_dataset))

    forget_dataset = ObjectTOForgetAngleDataset(
        forget_image_path='../zero123_dataset/object_to_forget_angle/image/02.png',
        override_image_path='../zero123_dataset/object_to_forget_angle/image/real_plane.png',
        angles_csv_path='../zero123_dataset/object_to_forget_angle/fixed_views_with_roll.csv',
        preprocess_fn=img_preprocess,
        carvekit_model=carvekit_model,
        transform=None,
    )
    print(len(forget_dataset))

    # Both loaders share the same settings: batches of four, shuffled.
    loader_options = {"batch_size": 4, "shuffle": True}
    common_dataloader = DataLoader(common_dataset, **loader_options)
    forget_dataloader = DataLoader(forget_dataset, **loader_options)

    # Walk through every batch and report the tensor shapes.
    for batch in common_dataloader:
        images, angles = batch
        print("Batch of images:", images.shape)
        print("Batch of angles:", angles.shape)

    for batch in forget_dataloader:
        forget_images, override_images, angles = batch
        print("Batch of forget images:", forget_images.shape)
        print("Batch of override images:", override_images.shape)
        print("Batch of forget angles:", angles.shape)

    # for i in range(1):
    #     img, forget_angle, override_angle = next(obj2forgetIter)
    #     # if i == 0:
    #     #     img.show()
    #     print(f'forget_angle {i}: ', forget_angle)
    #     print('override_angle: ', override_angle)
    #     print(img)


