from pathlib import Path
import time

import cv2
import lmdb
import numpy as np
import torch as pt
import torch.utils.data as ptud

from .dataset import compress, decompress
from .utils import (
    color_segment_to_index_segment_and_bbox,
    even_resize_and_center_crop,
    normaliz_for_visualiz,
    draw_segmentation_np,
)


class ClevrTex(ptud.Dataset):
    """ClevrTex: A Texture-Rich Benchmark for Unsupervised Multi-Object Segmentation.

    Source: https://www.robots.ox.ac.uk/~vgg/data/clevrtex (original version)
    - ClevrTex, parts 1-5 (4.7 GB each), 50000 scenes in total
    - ClevrTex-OOD test set (5.3 GB), 10000 scenes

    Samples are read from an LMDB file produced by ``convert_dataset``; the
    ``split`` argument selects a subset of the stored keys through ``IDXS``.

    Number of objects distribution:
    - train: {3: 5463, 4: 5507, 5: 5464, 6: 5505, 7: 5481, 8: 5377, 9: 5414, 10: 5539}
    - val: {3: 740, 4: 790, 5: 802, 6: 777, 7: 813, 8: 790, 9: 778, 10: 760}
    """

    # Positions into the stored ``__keys__`` list for each split. "train" and
    # "val" partition the 50000-scene main set (val = every 8th scene); "ood"
    # covers the whole 10000-scene OOD set.
    IDXS = {
        "train": [i for i in range(50000) if i % 8 != 0],
        "val": list(range(0, 50000, 8)),
        "ood": list(range(10000)),
    }

    def __init__(
        self,
        data_file,
        split,
        transform=lambda **_: _,
        max_spare=4,
        base_dir: Path = None,
    ):
        """Open the LMDB file read-only and select the keys belonging to ``split``.

        :param data_file: path of the ``.lmdb`` file written by ``convert_dataset``;
            interpreted relative to ``base_dir`` when that is given.
        :param split: a key of ``IDXS``: "train", "val" or "ood".
        :param transform: called as ``transform(**sample)``; its return value is
            what ``__getitem__`` returns. The default returns the sample dict.
        :param max_spare: forwarded to ``lmdb.open(max_spare_txns=...)``.
        :param base_dir: optional directory prepended to ``data_file``; may be a
            ``str`` or a ``Path``.
        """
        if base_dir:
            # Wrap in Path so callers may also pass a plain string.
            data_file = Path(base_dir) / data_file
        # NOTE(review): the environment is opened here, i.e. before any
        # DataLoader worker fork; LMDB's caveats advise against using an
        # environment after fork() -- confirm this is safe with readonly/lock=False.
        self.env = lmdb.open(
            str(data_file),
            subdir=False,
            readonly=True,
            readahead=False,
            meminit=False,
            max_spare_txns=max_spare,
            lock=False,
        )
        with self.env.begin(write=False) as txn:
            all_keys = decompress(txn.get(b"__keys__"))
        self.keys = [all_keys[i] for i in self.IDXS[split]]
        self.transform = transform

    def __getitem__(self, index):
        """Load one scene and return ``self.transform(**sample)``.

        The sample dict handed to ``transform`` contains:
        - image: (c=3,h,w), as stored by ``convert_dataset`` (read with
          cv2.imread, so BGR order; presumably uint8 -- any float conversion is
          up to ``transform``, TODO confirm)
        - bbox: (n,c=4), box corners (x1,y1,x2,y2) normalized by the image side
        - smask: (n,), bool, all True; marks valid slots
        - segment: (h,w), index segmentation
        - depth: (h,w), first channel of the stored depth image
        """
        with self.env.begin(write=False) as txn:
            record = decompress(txn.get(self.keys[index]))
        num_objects = record["bbox"].shape[0]
        sample = dict(
            image=pt.from_numpy(record["image"]).permute(2, 0, 1),
            bbox=pt.from_numpy(record["bbox"]),
            smask=pt.ones([num_objects], dtype=pt.bool),
            segment=pt.from_numpy(record["segment"]),
            depth=pt.from_numpy(record["depth"]),
        )
        return self.transform(**sample)

    def __len__(self):
        """Number of scenes in the selected split."""
        return len(self.keys)

    @staticmethod
    def convert_dataset(
        src_dir=Path("/media/GeneralZ/Storage/Static/datasets/clevrtex-original"),
        dst_dir=Path("clevrtex"),
        resolut=(128, 128),  # spatial downsample: (240,320)->(h,w)
    ):
        """Convert the raw ClevrTex PNGs into a single LMDB file.

        Expected layout of ``src_dir``:
        - 0
          - *.png
        ...
        - 49
          - *.png

        Every scene must contribute exactly 6 PNGs. Within each sorted group of
        6, the 1st is the RGB image, the 3rd the depth map (``*_depth_0001.png``)
        and the 4th the flat-color segmentation (``*_flat.png``). Output is
        ``dst_dir/data.lmdb`` plus an empty marker file ``dst_dir/{h}_{w}``
        recording the resolution. Sample keys are ``f"{i:06d}"``; the key list
        is stored under ``b"__keys__"``.

        :param src_dir: root of the extracted ClevrTex (or ClevrTex-OOD) data.
        :param dst_dir: output directory, created if missing.
        :param resolut: output (h, w); must be square.
        """
        # Validate before touching the filesystem so a bad call leaves no trace.
        assert resolut[0] == resolut[1]
        side = resolut[0]

        dst_dir.mkdir(parents=True, exist_ok=True)
        info_file = dst_dir / f"{resolut[0]}_{resolut[1]}"
        info_file.touch()

        dst_file = dst_dir / "data.lmdb"
        lmdb_env = lmdb.open(
            str(dst_file),
            map_size=1024**4,
            subdir=False,
            readonly=False,
            meminit=False,
        )
        t0 = time.time()
        try:
            keys = []
            txn = lmdb_env.begin(write=True)

            # NOTE: Path objects sort lexicographically per path component, so
            # directory "10" sorts before "2"; keys therefore follow this order,
            # not necessarily the original scene numbering. Kept unchanged so
            # the train/val split matches previously converted files.
            scenes = sorted(src_dir.glob("**/*.png"))

            total_num = len(scenes) // 6
            assert total_num in [10000, 50000]

            for cnt in range(total_num):
                files = scenes[cnt * 6 : cnt * 6 + 6]
                assert files[0].name.split(".")[0].split("_")[2].isnumeric()
                image_file = str(files[0])
                assert files[3].name.endswith("_flat.png")
                flat_file = str(files[3])
                assert files[2].name.endswith("_depth_0001.png")
                depth_file = str(files[2])

                image = cv2.imread(image_file)
                image = even_resize_and_center_crop(image, side)

                # Nearest-neighbour keeps the flat colors exact, so each object
                # color still maps to exactly one segment index.
                segment = cv2.imread(flat_file)
                segment = even_resize_and_center_crop(
                    segment, side, cv2.INTER_NEAREST_EXACT
                )
                segment, bbox = color_segment_to_index_segment_and_bbox(segment)

                depth = cv2.imread(depth_file)
                depth = even_resize_and_center_crop(depth, side)[:, :, 0]

                sample_key = f"{cnt:06d}".encode("ascii")
                keys.append(sample_key)

                sample_dict = dict(
                    image=image,  # (h,w,c=3)
                    bbox=bbox / side,  # (n,c=4)
                    segment=segment,  # (h,w)
                    depth=depth,  # (h,w)
                )
                txn.put(sample_key, compress(sample_dict))

                if (cnt + 1) % 64 == 0:  # write_freq
                    print(f"{(cnt+1):06d}")
                    txn.commit()
                    txn = lmdb_env.begin(write=True)

            txn.commit()
            txn = lmdb_env.begin(write=True)
            txn.put(b"__keys__", compress(keys))
            txn.commit()
        finally:
            # Closing the environment also invalidates any transaction left
            # open by an exception above, and releases the file handle.
            lmdb_env.close()

        print(f"total={total_num}, time={time.time() - t0}")

    @staticmethod
    def visualiz(image, bbox=None, segment=None, depth=None, wait=0):
        """Show a sample in OpenCV windows "i" (image + boxes), "s" and "d".

        :param image: (c,h,w) tensor or (h,w,c) array; passed through
            ``normaliz_for_visualiz`` and mapped by ``x * 127.5 + 127.5`` to uint8.
        :param bbox: (n,4) boxes (x1,y1,x2,y2) normalized to the image size;
            the caller's array/tensor is not modified.
        :param segment: (h,w) segmentation, overlaid on the image via
            ``draw_segmentation_np``.
        :param depth: (h,w) depth map, shown as-is.
        :param wait: forwarded to ``cv2.waitKey``.
        :return: the uint8 image with boxes drawn, and the segmentation overlay
            (or ``segment`` unchanged when it is None).
        """
        if isinstance(image, pt.Tensor):
            image = image.permute(1, 2, 0).cpu().contiguous().numpy()
        image = normaliz_for_visualiz(image)
        image = np.clip(image * 127.5 + 127.5, 0, 255).astype("uint8")

        if bbox is not None:
            if isinstance(bbox, pt.Tensor):
                bbox = bbox.cpu().numpy()
            # Scale a float copy: ``.numpy()`` on a CPU tensor shares memory with
            # it, so scaling in place would corrupt the caller's boxes.
            bbox = np.array(bbox, dtype=np.float64)
            bbox[:, 0::2] *= image.shape[1]
            bbox[:, 1::2] *= image.shape[0]
            for box in bbox.astype("int"):
                # OpenCV expects points as tuples of Python ints.
                pt1 = (int(box[0]), int(box[1]))
                pt2 = (int(box[2]), int(box[3]))
                image = cv2.rectangle(image, pt1, pt2, (63, 127, 255))

        cv2.imshow("i", image)

        if segment is not None:
            if isinstance(segment, pt.Tensor):
                segment = segment.cpu().numpy()
            segment = draw_segmentation_np(image, segment, 0.6)
            cv2.imshow("s", segment)

        if depth is not None:
            if isinstance(depth, pt.Tensor):
                depth = depth.cpu().numpy()
            cv2.imshow("d", depth)

        cv2.waitKey(wait)
        return image, segment
