import time
import gc
import os
import glob
import numpy as np
import torch
from copy import deepcopy
from torch.utils.data import Dataset
from collections.abc import Sequence
import shutil

# import camtools as ct
# import open3d as o3d
# import matplotlib.pyplot as plt

from pointcept.utils.logger import get_root_logger
from pointcept.utils.cache import shared_dict
from .builder import DATASETS
from .defaults import DefaultDataset
from .transform import Compose, TRANSFORMS
from PIL import Image
from io import BytesIO
from torchvision.transforms import transforms as T
from torchvision.utils import save_image
import torch.nn.functional as F
import trimesh
from pathlib import Path

# NOTE(review): these two calls run at import time and change printing for the
# whole process. Every torch tensor / numpy array printed anywhere is shown in
# full: an infinite threshold disables "..." summarisation. This looks like a
# debugging leftover; consider removing it or scoping it to a debug helper.
torch.set_printoptions(threshold=torch.inf)
np.set_printoptions(threshold=np.inf)


def convert_image(image):
    """Decode an in-memory encoded image into a PIL image.

    ``image`` is converted with ``np.array`` and its raw bytes are handed to
    ``Image.open``, so those bytes must form a complete encoded image file.
    Presumably this is a uint8 buffer of JPEG/PNG data; confirm with callers.
    """
    raw_bytes = np.array(image).tobytes()
    return Image.open(BytesIO(raw_bytes))


def resize_images_intrinsic(images, intrinsics, size, crop_size, _alignment):
    """Resize a batch of images and rescale the matching camera intrinsics.

    Args:
        images: batched image tensor for ``F.interpolate`` (bilinear mode);
            each ``images[i]`` must unpack as ``(C, H, W)``.
        intrinsics: stacked ``3x3`` matrices. Entries ``[:, 0, 0]`` and
            ``[:, 1, 1]`` are overwritten IN PLACE on the caller's array.
        size: ``(h, w)`` output size passed to ``F.interpolate``.
        crop_size: ``(crop_h, crop_w)``; also written into the focal entries.
        _alignment: unused.

    Returns:
        ``(resized_images, rescaled_intrinsics)``.
    """
    target_h, target_w = size
    # Unused, but the unpacking still requires each image to be (C, H, W).
    _src_h, _src_w = images[0].shape[1:]
    crop_h, crop_w = crop_size

    # NOTE(review): fx / fy are *replaced* by the crop width / height rather
    # than scaled from their stored values. Confirm this is intended.
    intrinsics[:, 1, 1] = crop_h
    intrinsics[:, 0, 0] = crop_w

    # Diagonal pixel-scale transform from the crop to the output resolution.
    scale = np.diag(
        np.array([target_w / crop_w, target_h / crop_h, 1], dtype=np.float32)
    )
    rescaled = np.stack([scale @ k for k in intrinsics])

    resized = F.interpolate(images, size, mode="bilinear", align_corners=False)
    return resized, rescaled


def resize_corresponding_info(corresponding, size, size0, crop_size, _alignment):
    """Map pixel correspondences from the original image onto a patch grid.

    Args:
        corresponding: integer array with rows ``(x, y, point_index)``.
            Column 0 is tested against left/right and column 1 against
            top/bottom of the crop box.
        size: ``(h, w)`` of the resized crop, in pixels.
        size0: ``(h0, w0)`` of the original image (unpacked but unused).
        crop_size: crop box ``(left, top, right, bottom)`` in original pixels.
        _alignment: patch side length; resized coordinates are floor-divided
            by it to obtain patch indices.

    Returns:
        Unique rows ``(row, col, point_index)`` in ``np.unique`` (sorted)
        order, keeping only correspondences that fall inside the crop box.
        The input array is not modified.
    """
    out_h, out_w = size
    _orig_h, _orig_w = size0
    left, top, right, bottom = crop_size
    box_h = bottom - top
    box_w = right - left

    xs = corresponding[:, 0]
    ys = corresponding[:, 1]
    inside = (left <= xs) & (xs < right) & (top <= ys) & (ys < bottom)

    # Boolean indexing returns a copy, so the in-place edits below are local.
    kept = corresponding[inside]
    kept[:, 0] -= left
    kept[:, 1] -= top
    kept[:, 1] = (kept[:, 1] * out_h / box_h // _alignment).astype(np.int32)
    kept[:, 0] = (kept[:, 0] * out_w / box_w // _alignment).astype(np.int32)

    return np.unique(kept[:, [1, 0, 2]], axis=0)


class GridSample(object):
    """Voxel-grid downsampling that keeps one randomly chosen point per cell.

    Coordinates are divided by ``grid_size`` and floored to integer cell
    indices. Points whose cells hash to the same key are grouped, and one
    member of each group is kept, chosen with the global ``np.random`` state.
    """

    def __init__(
        self,
        grid_size=0.05,
        hash_type="fnv",
    ):
        # Scalar or per-axis cell edge length (broadcast via np.array).
        self.grid_size = grid_size
        # "fnv" selects the FNV-style hash; any other value the ravel hash.
        self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec

    def __call__(self, points, colors=None):
        """Downsample ``points`` and, optionally, the matching ``colors``.

        Args:
            points: 2-D ``(N, D)`` coordinate array (the hashes require 2-D).
            colors: optional per-point array indexed in lock-step with
                ``points``.

        Returns:
            The kept points, or ``(points, colors)`` when ``colors`` is given.
            An empty input yields an empty output with the same trailing shape.
        """
        if len(points) == 0:
            # min()/max() below are undefined on zero-size arrays.
            if colors is not None:
                return points[:0], colors[:0]
            return points[:0]

        grid_coord = np.floor(points / np.array(self.grid_size)).astype(int)
        # Shift to non-negative cell indices (the hashes cast to uint64).
        grid_coord -= grid_coord.min(0)
        key = self.hash(grid_coord)
        idx_sort = np.argsort(key)
        # After sorting, equal keys are contiguous; count[i] is group i's size.
        _, count = np.unique(key[idx_sort], return_counts=True)
        group_start = np.cumsum(np.insert(count, 0, 0)[0:-1])
        # A random offset inside each group picks that group's representative.
        idx_select = group_start + np.random.randint(0, count.max(), count.size) % count
        idx_unique = idx_sort[idx_select]

        points = points[idx_unique]
        if colors is not None:
            return points, colors[idx_unique]
        return points

    @staticmethod
    def ravel_hash_vec(arr):
        """Collision-free key per row: ravel coordinates after min-shifting.

        Each column is shifted to start at 0 and the rows are linearised in
        row-major (C) order: the last column varies fastest.
        """
        assert arr.ndim == 2
        arr = arr.copy()
        arr -= arr.min(0)
        arr = arr.astype(np.uint64, copy=False)
        arr_max = arr.max(0).astype(np.uint64) + 1

        keys = np.zeros(arr.shape[0], dtype=np.uint64)
        # Horner-style accumulation: ((a0 * m1 + a1) * m2 + a2) ...
        for j in range(arr.shape[1] - 1):
            keys += arr[:, j]
            keys *= arr_max[j + 1]
        keys += arr[:, -1]
        return keys

    @staticmethod
    def fnv_hash_vec(arr):
        """64-bit FNV-style hash of each row of a 2-D integer array.

        For every column, the running hash is multiplied by the FNV prime and
        then XOR-ed with the column value. This is FNV-1 ordering (FNV-1a would
        XOR first), applied per 64-bit element rather than per byte. Distinct
        rows may collide.
        """
        assert arr.ndim == 2
        # Negative values would wrap in the uint64 cast; callers pass
        # min-shifted, non-negative coordinates.
        arr = arr.astype(np.uint64)
        hashed_arr = np.uint64(14695981039346656037) * np.ones(
            arr.shape[0], dtype=np.uint64
        )
        for j in range(arr.shape[1]):
            hashed_arr *= np.uint64(1099511628211)
            hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
        return hashed_arr


@DATASETS.register_module()
class RE10KDatasetALL_img(DefaultDataset):
    """Scene dataset pairing a point cloud with center-cropped RGB views.

    Each entry of ``data_list`` is a scene directory from which ``get_data``
    reads:

    * ``coord.npy``, ``color.npy``, ``normal.npy``: per-point arrays;
    * integer-named ``.jpg`` / ``.png`` images, sorted numerically, of which
      at most ``IMG_NUM`` are returned;
    * ``corresponding/<k>.npy``: integer rows ``(x, y, point_index)`` linking
      pixels to points. A lone ``[-1, -1, -1]`` row means "no matches".

    Scenes whose arrays cannot be read, or that have fewer than 10 points,
    are moved to ``<data_root>/failure`` and dropped from ``data_list``.
    """

    VALID_ASSETS = [
        "coord",
        "color",
        "extrinsics",
        "intrinsics",
        "extrinsics_pred",
        "intrinsics_pred",
        "mask",
        "normal",
    ]
    # Maximum number of views returned per scene.
    IMG_NUM = 4
    # Shared voxel downsampler; not referenced by the methods of this class.
    gridsample = GridSample(grid_size=0.01)

    def __init__(
        self,
        split="train",
        data_root="data/dataset",
        transform=None,
        test_mode=False,
        test_cfg=None,
        cache=False,
        ignore_index=-1,
        loop=1,
        img_ratio=1,
        crop_h=630,
        crop_w=1120,
        patch_size=14,
    ):
        """Forward the common options to ``DefaultDataset`` and set up views.

        Args (in addition to those of ``DefaultDataset``):
            crop_h, crop_w: target view height / width in pixels, rounded down
                to a multiple of ``patch_size``.
            patch_size: side of one image patch; correspondences are returned
                in patch units.
        """
        super().__init__(
            split,
            data_root,
            transform,
            test_mode,
            test_cfg,
            cache,
            ignore_index,
            loop,
            img_ratio,
        )
        self.crop_h = crop_h
        self.crop_w = crop_w
        self.patch_size = patch_size
        self.patch_h = crop_h // patch_size
        self.patch_w = crop_w // patch_size
        self.transform_img = T.Compose(
            [
                T.Resize(
                    (self.patch_h * self.patch_size, self.patch_w * self.patch_size)
                ),
                T.ToTensor(),
            ]
        )

    def get_data_list(self):
        """Return the scene list produced by ``DefaultDataset``, unchanged."""
        return super().get_data_list()

    def extract_string(self, input_string):
        """Return ``input_string`` without its last ``_suffix``.

        A string with no underscore is returned unchanged.
        """
        head, sep, _ = input_string.rpartition("_")
        return head if sep else input_string

    def get_data(self, idx):
        """Return the sample dict for ``idx``, skipping unusable scenes.

        A scene that fails to load is quarantined by ``_discard_scene`` and
        loading continues with ``idx + 1`` against the shortened list.

        Raises:
            RuntimeError: if every scene in ``data_list`` has been discarded.
        """
        while self.data_list:
            data_path = self.data_list[idx % len(self.data_list)]
            name = self.get_data_name(idx)
            data_dict = self._load_scene(data_path, name)
            if data_dict is not None:
                return data_dict
            self._discard_scene(data_path)
            idx += 1
        raise RuntimeError(f"no loadable scenes left under {self.data_root}")

    def _load_scene(self, data_path, name):
        """Build the sample dict for one scene, or ``None`` if it is unusable."""
        imgs, img_size, crop_box = self._load_views(data_path)
        try:
            coord = np.load(os.path.join(data_path, "coord.npy"))
            color = np.load(os.path.join(data_path, "color.npy"))
            normal = np.load(os.path.join(data_path, "normal.npy"))
        except Exception:
            return None
        # Checked early: np.max(color) below raises on an empty scene.
        if coord.shape[0] < 10:
            return None
        if np.max(color) <= 1.5:
            # Colors stored in roughly [0, 1] are rescaled to 0-255.
            color = np.clip(color * 255, 0.0, 255.0)

        mask_index = self._load_correspondences(
            data_path, coord.shape[0], img_size, crop_box
        )
        if mask_index is None:
            return None

        num_points = coord.shape[0]
        return {
            "coord": coord.astype(np.float32),
            "color": color.astype(np.float32),
            "normal": normal.astype(np.float32),
            "imgs": torch.stack(imgs).float(),
            "img_num": np.array([mask_index.shape[1]], dtype=np.int32),
            "mask_index": mask_index,
            "name": name,
            # No labels are stored for these scenes; fill with -1 placeholders.
            "segment": np.full(num_points, -1, dtype=np.int32),
            "instance": np.full(num_points, -1, dtype=np.int32),
        }

    def _load_views(self, data_path):
        """Load, center-crop and resize the first ``IMG_NUM`` views of a scene.

        Returns:
            ``(imgs, (img_height, img_width), (left, top, right, bottom))``:
            the view tensors, the size of the first image, and the crop box
            (derived from the first image) that was applied to every view.

        Raises:
            FileNotFoundError: if the directory holds no ``.jpg`` / ``.png``.
        """
        assets = os.listdir(data_path)
        names = [a for a in assets if a.endswith(".jpg")] + [
            a for a in assets if a.endswith(".png")
        ]
        names.sort(key=lambda x: int(x.split(".")[0]))
        if not names:
            raise FileNotFoundError(f"no .jpg/.png images found in {data_path}")

        with Image.open(os.path.join(data_path, names[0])) as first:
            img_width, img_height = first.size
        # Largest crop whose sides are the same integer multiple of
        # (patch_w, patch_h) and that fits the image (never below 1x).
        scale = max(min(img_width // self.patch_w, img_height // self.patch_h), 1)
        box_w = scale * self.patch_w
        box_h = scale * self.patch_h
        crop_box = (
            int((img_width - box_w) / 2),
            int((img_height - box_h) / 2),
            int((img_width + box_w) / 2),
            int((img_height + box_h) / 2),
        )

        imgs = []
        # Only the kept views are decoded; `with` closes each file promptly.
        for asset in names[: self.IMG_NUM]:
            with Image.open(os.path.join(data_path, asset)) as img:
                imgs.append(self.transform_img(img.crop(crop_box)))
        return imgs, (img_height, img_width), crop_box

    def _load_correspondences(self, data_path, num_points, img_size, crop_box):
        """Map every point to its patch ``(row, col)`` in each correspondence file.

        Returns:
            An ``int32`` array of shape ``(num_points, num_files, 2)`` with
            ``-1`` where a point has no match, or ``None`` if any
            correspondence file cannot be read.
        """
        corresponding_path = os.path.join(data_path, "corresponding")
        assets = sorted(
            os.listdir(corresponding_path), key=lambda x: int(x.split(".")[0])
        )
        mask_index = np.full((num_points, len(assets), 2), -1, dtype=np.int32)
        target_size = (self.patch_h * self.patch_size, self.patch_w * self.patch_size)
        for file_id, asset in enumerate(assets):
            try:
                info = np.load(os.path.join(corresponding_path, asset)).astype(
                    np.int32
                )
            except Exception:
                return None
            # A single (-1, -1, -1) row is the "no correspondences" sentinel.
            if np.array_equal(info, -np.ones((1, 3))):
                continue
            info = resize_corresponding_info(
                info, target_size, img_size, crop_box, self.patch_size
            )
            mask_index[info[:, -1], file_id, :] = info[:, :-1]
        return mask_index

    def _discard_scene(self, data_path):
        """Move a broken scene into ``<data_root>/failure`` and stop sampling it."""
        failure_dir = os.path.join(self.data_root, "failure")
        # shutil.move *renames* src to dst when dst does not exist, so the
        # folder must exist for the scene to be moved into it.
        os.makedirs(failure_dir, exist_ok=True)
        try:
            shutil.move(data_path, failure_dir)
        except OSError as err:
            # Best effort (e.g. a same-named scene already in failure/): the
            # scene is still dropped from this dataset's list below.
            get_root_logger().warning(
                f"could not move {data_path} to {failure_dir}: {err}"
            )
        self.data_list.remove(data_path)
