import os
import glob
from .defaults import DefaultDataset
from .builder import DATASETS
import numpy as np
import torch
from copy import deepcopy
from torch.utils.data import Dataset
from collections.abc import Sequence
import shutil

from pointcept.utils.logger import get_root_logger
from pointcept.utils.cache import shared_dict
from .builder import DATASETS
from .defaults import DefaultDataset
from .transform import Compose, TRANSFORMS
from PIL import Image
from io import BytesIO
from torchvision.transforms import transforms as T
from torchvision.utils import save_image
import torch.nn.functional as F
import trimesh
from pathlib import Path

# Never abbreviate printed tensors/arrays with "...": an infinite threshold
# disables summarisation. These are process-wide settings that take effect as
# soon as this module is imported.
torch.set_printoptions(threshold=torch.inf)
np.set_printoptions(threshold=np.inf)


def convert_image(image):
    """Open a PIL image from the raw bytes of ``image``.

    ``image`` is converted with ``np.array(...)`` and its underlying buffer is
    handed to ``Image.open`` through an in-memory stream, so it is expected to
    hold an encoded image file (presumably a uint8 byte buffer — TODO confirm
    against callers).
    """
    encoded = np.array(image).tobytes()
    stream = BytesIO(encoded)
    return Image.open(stream)


def resize_images_intrinsic(images, intrinsics, size, crop_size):
    """Resize a batch of images and rescale the matching camera matrices.

    Args:
        images: batch-first 4-D tensor accepted by ``F.interpolate`` with a
            two-element ``size`` (assumed ``(N, C, H, W)`` — TODO confirm).
        intrinsics: array-like of per-image 3x3 matrices indexed as
            ``[image, row, col]``. It is NOT modified; a copy is used.
        size: ``(h, w)`` of the resized images.
        crop_size: ``(crop_h, crop_w)`` used both as the new ``[1, 1]`` /
            ``[0, 0]`` entries and as the denominator of the scale factors.

    Returns:
        Tuple ``(images, intrinsics)``: the bilinearly resized image tensor and
        a new ``(N, 3, 3)`` array of rescaled matrices.
    """
    h, w = size
    crop_h, crop_w = crop_size

    # Work on a private copy so the caller's matrices are left untouched.
    intrinsics = np.array(intrinsics)
    # NOTE(review): the incoming [0, 0] / [1, 1] entries are discarded and
    # replaced by the crop size before scaling — confirm this is intended.
    intrinsics[:, 0, 0] = crop_w
    intrinsics[:, 1, 1] = crop_h

    cam_trans = np.array(
        [[w / crop_w, 0, 0], [0, h / crop_h, 0], [0, 0, 1]], dtype=np.float32
    )
    # Matmul broadcasts the single 3x3 scale matrix over the leading axis.
    intrinsics = cam_trans @ intrinsics

    images = F.interpolate(images, size, mode="bilinear", align_corners=False)
    return images, intrinsics


def resize_corresponding_info(corresponding, size, size0, crop_size, _alignment):
    """Map pixel correspondences into the patch grid of a cropped, resized image.

    Args:
        corresponding: 2-D array whose columns are read as ``(x, y, id)``:
            column 0 is tested against the left/right crop bounds, column 1
            against top/bottom, and column 2 is carried through unchanged (the
            dataset class in this file uses it as a point index).
        size: ``(h, w)`` of the resized image.
        size0: two-element size of the original image; unpacked but otherwise
            unused.
        crop_size: crop box ``(left, top, right, bottom)``; right and bottom
            are exclusive.
        _alignment: divisor applied (floor division) to the resized pixel
            coordinates, turning them into patch indices.

    Returns:
        A new array of unique rows ``(row, col, id)`` in the order produced by
        ``np.unique(..., axis=0)``. Rows outside the crop box are dropped and
        the input array is not modified.
    """
    out_h, out_w = size
    _src_h, _src_w = size0  # only validates that size0 has two elements
    x_min, y_min, x_max, y_max = crop_size
    box_h = y_max - y_min
    box_w = x_max - x_min

    inside = (
        (corresponding[:, 1] >= y_min)
        & (corresponding[:, 1] < y_max)
        & (corresponding[:, 0] >= x_min)
        & (corresponding[:, 0] < x_max)
    )
    # Boolean-mask indexing yields a copy, so the edits below stay local.
    kept = corresponding[inside]

    # Shift into crop-local pixel coordinates.
    kept[:, 1] -= y_min
    kept[:, 0] -= x_min

    # Scale to the resized image, then bucket pixels into patches.
    patch_rows = kept[:, 1] * out_h / box_h // _alignment
    kept[:, 1] = patch_rows.astype(np.int32)
    patch_cols = kept[:, 0] * out_w / box_w // _alignment
    kept[:, 0] = patch_cols.astype(np.int32)

    # Reorder columns from (x, y, id) to (row, col, id) and de-duplicate.
    return np.unique(kept[:, [1, 0, 2]], axis=0)


@DATASETS.register_module()
class S3DISDatasetALL_img(DefaultDataset):
    """Dataset pairing a room point cloud with camera images and pixel links.

    Samples are the directories matched by ``<data_root>/<split>/*/*/*``; the
    last three path components are read as ``<area>/<room>/<camera>``. Each
    sample directory contains ``.jpg``/``.png`` frames whose file names hold
    ``_frame_<int>_domain_`` and a ``corresponding`` sub-directory with one
    array file per frame. The point cloud is loaded from
    ``<data_root>/pointclouds/<area>/<room>/{coord,color,normal}.npy``.
    """

    # Upper bound on the number of images returned per sample.
    IMG_NUM = 4

    def __init__(
        self,
        split="train",
        data_root="data/dataset",
        transform=None,
        test_mode=False,
        test_cfg=None,
        cache=False,
        ignore_index=-1,
        loop=1,
        img_ratio=1,
        crop_h=630,
        crop_w=1120,
        patch_size=14,
    ):
        """Configure the dataset and the image pre-processing pipeline.

        Args:
            split, data_root, transform, test_mode, test_cfg, cache,
            ignore_index, loop, img_ratio: forwarded positionally, in this
                order, to ``DefaultDataset.__init__``; see the base class for
                their meaning. ``get_data`` in this class always reads from
                disk regardless of ``cache``.
            crop_h: target image height before rounding down to a multiple of
                ``patch_size``.
            crop_w: target image width before rounding down to a multiple of
                ``patch_size``.
            patch_size: side length, in pixels, of one image patch.
        """
        super().__init__(
            split,
            data_root,
            transform,
            test_mode,
            test_cfg,
            cache,
            ignore_index,
            loop,
            img_ratio,
        )
        self.crop_h = crop_h
        self.crop_w = crop_w
        self.patch_size = patch_size
        # Number of whole patches that fit along each image axis.
        self.patch_h = crop_h // patch_size
        self.patch_w = crop_w // patch_size
        # Resize to an exact multiple of the patch size, then convert to a
        # tensor; no normalisation is applied here.
        self.transform_img = T.Compose(
            [
                T.Resize(
                    (self.patch_h * self.patch_size, self.patch_w * self.patch_size)
                ),
                T.ToTensor(),
            ]
        )

    @staticmethod
    def _split_data_path(data_path):
        """Return ``(area, room, camera)``: the last three components of a path."""
        remain, cam_name = os.path.split(data_path)
        remain, room_name = os.path.split(remain)
        _, area_name = os.path.split(remain)
        return area_name, room_name, cam_name

    @staticmethod
    def _frame_index(asset_name):
        """Return the integer between ``_frame_`` and ``_domain_`` in a file name."""
        return int(asset_name.split("_frame_")[1].split("_domain_")[0])

    def get_data_name(self, idx):
        """Return ``"<area>-<room>-<camera>"`` for sample ``idx`` (wraps around)."""
        data_path = self.data_list[idx % len(self.data_list)]
        area_name, room_name, cam_name = self._split_data_path(data_path)
        return f"{area_name}-{room_name}-{cam_name}"

    def get_data_list(self):
        """Collect the sample directories of the configured split(s).

        Raises:
            NotImplementedError: if ``self.split`` is neither a string nor a
                sequence of strings.
        """
        if isinstance(self.split, str):
            data_list = glob.glob(
                os.path.join(self.data_root, self.split, "*", "*", "*")
            )
        elif isinstance(self.split, Sequence):
            data_list = []
            for split in self.split:
                data_list += glob.glob(
                    os.path.join(self.data_root, split, "*", "*", "*")
                )
        else:
            raise NotImplementedError
        return data_list

    def extract_string(self, input_string):
        """Split a ``"<area>-<room>-<camera>"`` name and return ``(area, room)``.

        Raises ``ValueError`` unless the string has exactly three "-"-separated
        parts, so it cannot handle names that themselves contain "-".
        ``get_data`` therefore reads the names from the path instead.
        """
        area_name, room_name, _cam_name = input_string.split("-")
        return area_name, room_name

    def _center_crop_box(self, img_width, img_height):
        """Return a centred ``(left, top, right, bottom)`` crop box.

        The box is the largest integer multiple of ``(patch_w, patch_h)`` that
        fits inside the image, but never less than one multiple, so it has the
        same aspect ratio as the patch grid.
        """
        div_w = img_width // self.patch_w
        div_h = img_height // self.patch_h
        div_min = max(min(div_w, div_h), 1)
        crop_img_width = div_min * self.patch_w
        crop_img_height = div_min * self.patch_h
        left = int((img_width - crop_img_width) / 2)
        top = int((img_height - crop_img_height) / 2)
        right = int((img_width + crop_img_width) / 2)
        bottom = int((img_height + crop_img_height) / 2)
        return left, top, right, bottom

    def _quarantine_and_skip(self, data_path, name, idx):
        """Move an unusable sample aside, drop it from the list, load the next.

        WARNING: this moves the sample directory on disk into
        ``<data_root>/failure/<name>`` and mutates ``self.data_list``.
        """
        failure_root = os.path.join(self.data_root, "failure")
        # Without an existing target directory shutil.move would rename the
        # sample itself to "failure"; the unique sample name also avoids
        # collisions between equally named camera directories.
        os.makedirs(failure_root, exist_ok=True)
        shutil.move(data_path, os.path.join(failure_root, name))
        self.data_list.remove(data_path)
        if not self.data_list:
            raise RuntimeError(
                f"no usable samples left under {self.data_root!r}; "
                "all of them were moved to the 'failure' directory"
            )
        return self.get_data(idx + 1)

    def get_data(self, idx):
        """Load one sample from disk.

        Returns:
            dict with keys
            ``coord``, ``color``, ``normal``: float32 arrays loaded from the
                room's point cloud (``color`` is multiplied by 255 when its
                maximum is <= 1);
            ``imgs``: float tensor stacking at most ``IMG_NUM`` cropped and
                resized images;
            ``img_num``: one-element int32 array with the number of files in
                ``corresponding``;
            ``mask_index``: int32 array ``(num_points, num_files, 2)`` holding
                the ``(row, col)`` patch of each point per file, ``-1`` where
                there is no correspondence;
            ``name``: the sample name;
            ``segment``, ``instance``: int32 arrays filled with ``-1`` (this
                class loads no labels).

        A sample whose point cloud has fewer than 10 points is moved to the
        ``failure`` directory and the next sample is returned instead.
        """
        data_path = self.data_list[idx % len(self.data_list)]
        name = self.get_data_name(idx)
        # Take area/room from the path rather than re-splitting ``name`` on
        # "-", which breaks when a directory name contains a hyphen.
        area_name, room_name, _ = self._split_data_path(data_path)

        pointclouds_path = os.path.join(
            self.data_root, "pointclouds", area_name, room_name
        )
        coord = np.load(os.path.join(pointclouds_path, "coord.npy"))
        if coord.shape[0] < 10:
            # Reject early, before any image or correspondence is decoded.
            return self._quarantine_and_skip(data_path, name, idx)
        color = np.load(os.path.join(pointclouds_path, "color.npy"))
        normal = np.load(os.path.join(pointclouds_path, "normal.npy"))
        if np.max(color) <= 1:
            color = color * 255

        # Images: order by frame index and keep only the first IMG_NUM, so
        # frames that would be discarded are never opened or resized.
        assets = os.listdir(data_path)
        jpgs = [asset for asset in assets if asset.endswith(".jpg")]
        pngs = [asset for asset in assets if asset.endswith(".png")]
        image_assets = sorted(jpgs + pngs, key=self._frame_index)[: self.IMG_NUM]
        pil_imgs = [Image.open(os.path.join(data_path, asset)) for asset in image_assets]
        # The crop box is derived from the first image and applied to all.
        img_width, img_height = pil_imgs[0].size
        crop_box = self._center_crop_box(img_width, img_height)
        imgs = [self.transform_img(img.crop(crop_box)) for img in pil_imgs]

        # Correspondences: one file per frame, rows are read as (x, y, point).
        corresponding_path = os.path.join(data_path, "corresponding")
        corresponding_assets = sorted(
            os.listdir(corresponding_path), key=self._frame_index
        )
        corresponding_infos = np.full(
            (coord.shape[0], len(corresponding_assets), 2), -1, dtype=np.int32
        )
        resized_hw = (self.patch_h * self.patch_size, self.patch_w * self.patch_size)
        for asset_id, asset in enumerate(corresponding_assets):
            corresponding_info = np.load(
                os.path.join(corresponding_path, asset)
            ).astype(np.int32)
            # A single all -1 row marks a frame without correspondences.
            if np.array_equal(corresponding_info, -np.ones((1, 3))):
                continue
            corresponding_info = resize_corresponding_info(
                corresponding_info,
                resized_hw,
                (img_height, img_width),
                crop_box,
                self.patch_size,
            )
            # Last column selects the point, the first two are its patch.
            corresponding_infos[corresponding_info[:, -1], asset_id, :] = (
                corresponding_info[:, :-1]
            )

        # NOTE(review): ``imgs`` holds at most IMG_NUM images, whereas
        # ``img_num`` and the second axis of ``mask_index`` follow the number
        # of files in ``corresponding`` — confirm the two always agree.
        num_points = coord.shape[0]
        data_dict = {
            "coord": coord.astype(np.float32),
            "color": color.astype(np.float32),
            "normal": normal.astype(np.float32),
            "imgs": torch.stack(imgs).float(),
            "img_num": np.array([len(corresponding_assets)], dtype=np.int32),
            "mask_index": corresponding_infos,
            "name": name,
            # No labels are loaded here, so both are "unlabelled" everywhere.
            "segment": np.full(num_points, -1, dtype=np.int32),
            "instance": np.full(num_points, -1, dtype=np.int32),
        }
        return data_dict
