from pathlib import Path
import time

import cv2
import lmdb
import numpy as np
import torch as pt
import torch.utils.data as ptud

from .dataset import compress, decompress
from .utils import (
    color_segment_to_index_segment_and_bbox,
    even_resize_and_center_crop,
    normaliz_for_visualiz,
    draw_segmentation_np,
)


class MSCOCO(ptud.Dataset):
    """Common Objects in COntext  https://cocodataset.org/#download
    - 2017 Train images [118K/18GB]
    - 2017 Val images [5K/1GB]
    - 2017 Panoptic Train/Val annotations [821MB]

    Samples are read from an LMDB file produced by ``convert_dataset``.

    The dicts below appear to map object count -> number of kept images per
    split (presumably for the default ``max_nobj=10`` settings; re-check when
    converting with other arguments):
    train: {1: 977, 2: 3962, 3: 6739, 4: 6063, 5: 4330, 6: 2627, 7: 1359, 8: 678, 9: 274, 10: 107}
    val: {1: 28, 2: 182, 3: 301, 4: 230, 5: 186, 6: 116, 7: 41, 8: 31, 9: 6, 10: 6}
    """

    def __init__(
        self, data_file, transform=lambda **_: _, max_spare=4, base_dir: Path = None
    ):
        """Open the LMDB file read-only and load its key list.

        data_file: path of the ``.lmdb`` file (joined onto ``base_dir`` if given)
        transform: called as ``transform(**sample)``; the default returns the
            keyword arguments unchanged, as a dict
        max_spare: forwarded to ``lmdb.open(max_spare_txns=...)``
        base_dir: optional directory prefix for ``data_file``
        """
        if base_dir:
            data_file = base_dir / data_file
        self.env = lmdb.open(
            str(data_file),
            subdir=False,
            readonly=True,
            readahead=False,
            meminit=False,
            max_spare_txns=max_spare,
            lock=False,
        )
        with self.env.begin(write=False) as txn:
            self.keys = decompress(txn.get(b"__keys__"))
        self.transform = transform

    def __getitem__(self, index):
        """Load one sample and return ``self.transform(**sample)``.

        Before the transform the sample holds:
        image: in shape (c=3,h,w); as stored this is the BGR ``cv2.imread``
            output (uint8 unless ``even_resize_and_center_crop`` changes the
            dtype) -- convert, e.g. to float32, in ``transform`` if needed
        bbox: in shape (n,c=4), divided by the image side, background excluded
        smask: in shape (n,), bool, all True here (masking slots)
        segment: in shape (h,w), index map from
            ``color_segment_to_index_segment_and_bbox`` with 0 = background
        """
        with self.env.begin(write=False) as txn:
            sample = decompress(txn.get(self.keys[index]))
        n = sample["bbox"].shape[0]
        sample = dict(
            image=pt.from_numpy(sample["image"]).permute(2, 0, 1),
            bbox=pt.from_numpy(sample["bbox"]),
            smask=pt.ones([n], dtype=pt.bool),
            segment=pt.from_numpy(sample["segment"]),
        )
        return self.transform(**sample)

    def __len__(self):
        """Number of samples, i.e. the number of stored keys."""
        return len(self.keys)

    @staticmethod
    def convert_dataset(
        src_dir=Path("/media/GeneralZ/Storage/Static/datasets/MSCOCO"),
        dst_dir=Path("coco"),
        resolut=(128, 128),
        max_nobj=10,  # images containing too many objects will be discarded
        min_area=250,  # images containing smallest area less than this will be discarded
    ):
        """Convert COCO 2017 images + panoptic PNGs into ``{split}.lmdb`` files.

        Structure the source dataset as follows and run it!
        - annotations
          - panoptic_train2017
            - *.png
          - panoptic_val2017
            - *.png
        - train2017
          - *.jpg
        - val2017
          - *.jpg

        Each image and its panoptic PNG are resized and center-cropped to a
        square of side ``resolut[0]`` (the PNG with nearest-neighbour
        interpolation). The color segment is converted to an index map plus
        boxes. Samples are skipped when they have more than ``max_nobj`` boxes
        or any foreground region smaller than ``min_area`` pixels. Kept samples
        are stored under keys ``b"000001"``, ``b"000002"``, ... and the key
        list under ``b"__keys__"``. An empty marker file named
        ``{h}_{w}-{max_nobj}`` records the settings in ``dst_dir``.

        Raises AssertionError if ``resolut`` is not square or an image/PNG
        pair has mismatched names, and FileNotFoundError if OpenCV cannot
        decode a file.
        """
        # Validate before touching the filesystem so bad arguments leave nothing behind.
        assert resolut[0] == resolut[1]
        side = resolut[0]
        dst_dir.mkdir(parents=True, exist_ok=True)
        info_file = dst_dir / f"{resolut[0]}_{resolut[1]}-{max_nobj}"
        info_file.touch()

        splits = dict(
            train=["train2017", "annotations/panoptic_train2017"],
            val=["val2017", "annotations/panoptic_val2017"],
        )

        for split, (image_dn, segment_dn) in splits.items():
            print(split, image_dn, segment_dn)

            # Sorting both listings pairs each jpg with its png by file name.
            image_files = sorted((src_dir / image_dn).iterdir())
            segment_files = sorted((src_dir / segment_dn).iterdir())

            lmdb_env = lmdb.open(
                str(dst_dir / f"{split}.lmdb"),
                map_size=1024**4,
                subdir=False,
                readonly=False,
                meminit=False,
            )
            try:
                keys = []
                cnt = 0
                t0 = time.time()
                txn = lmdb_env.begin(write=True)

                for i, (image_file, segment_file) in enumerate(
                    zip(image_files, segment_files)
                ):
                    assert image_file.stem == segment_file.stem, (
                        image_file,
                        segment_file,
                    )

                    image = cv2.imread(str(image_file))
                    if image is None:  # cv2.imread signals failure with None
                        raise FileNotFoundError(f"cannot read image {image_file}")
                    image = even_resize_and_center_crop(image, side)

                    segment = cv2.imread(str(segment_file))
                    if segment is None:
                        raise FileNotFoundError(f"cannot read segment {segment_file}")
                    segment = even_resize_and_center_crop(
                        segment, side, cv2.INTER_NEAREST_EXACT
                    )
                    segment, bbox = color_segment_to_index_segment_and_bbox(segment)

                    if not MSCOCO._keep_sample(segment, bbox, max_nobj, min_area):
                        continue
                    cnt += 1

                    # MSCOCO.visualiz(image, bbox / side, segment, 0)

                    sample_key = f"{cnt:06d}".encode("ascii")
                    keys.append(sample_key)

                    sample = dict(
                        image=image,  # (h,w,c=3)
                        bbox=bbox / side,  # (n,c=4)
                        segment=segment,  # (h,w)
                    )
                    txn.put(sample_key, compress(sample))

                    if cnt % 64 == 0:  # write_freq
                        print(f"{cnt:06d}, {i}")
                        txn.commit()
                        txn = lmdb_env.begin(write=True)

                txn.commit()
                # seconds per kept sample; guard against an empty split
                print((time.time() - t0) / max(cnt, 1))

                txn = lmdb_env.begin(write=True)
                txn.put(b"__keys__", compress(keys))
                txn.commit()
            finally:
                # Closing the environment also invalidates any still-open write txn.
                lmdb_env.close()

    @staticmethod
    def _keep_sample(segment, bbox, max_nobj, min_area):
        """Return True if a converted sample passes the dataset filters.

        segment: index map where 0 is background (need not be present)
        bbox: per-object boxes; only ``bbox.shape[0]`` (object count) is used
        A sample is rejected if it has more than ``max_nobj`` boxes, no
        foreground pixels, or any foreground index covering fewer than
        ``min_area`` pixels.
        """
        if bbox.shape[0] > max_nobj:
            return False
        ids, counts = np.unique(segment, return_counts=True)
        fg_counts = counts[ids != 0]  # drop background without requiring it
        return fg_counts.size > 0 and int(fg_counts.min()) >= min_area

    @staticmethod
    def filter_bbox(bbox, min_bbox):
        """Keep boxes whose pixel-inclusive area is at least ``min_bbox``.

        bbox: in shape (n,c), ltrb; area = (r - l + 1) * (b - t + 1)
        Raises AssertionError if any box has area < 1 (empty or inverted box).
        """
        area = (bbox[:, 2] - bbox[:, 0] + 1) * (bbox[:, 3] - bbox[:, 1] + 1)
        # compare elementwise, so negative (inverted) areas are rejected too
        assert np.all(area >= 1)
        return bbox[area >= min_bbox]

    @staticmethod
    def visualiz(image, bbox=None, segment=None, wait=0):
        """Show an image (window "i") and optionally a segmentation overlay (window "s").

        image: (c,h,w) tensor or (h,w,c) array; passed through
            ``normaliz_for_visualiz`` and then mapped to uint8 via
            ``clip(x * 127.5 + 127.5, 0, 255)``
        bbox: optional (n,4) ltrb boxes normalized to the image size; drawn in
            black. The caller's array/tensor is not modified.
        segment: optional (h,w) index map drawn with ``draw_segmentation_np``
        wait: forwarded to ``cv2.waitKey``
        Returns the drawn uint8 image and the overlay (``None`` if no segment).
        """
        if isinstance(image, pt.Tensor):
            image = image.detach().permute(1, 2, 0).cpu().contiguous().numpy()
        image = normaliz_for_visualiz(image)
        image = np.clip(image * 127.5 + 127.5, 0, 255).astype("uint8")

        if bbox is not None:
            if isinstance(bbox, pt.Tensor):
                bbox = bbox.detach().cpu().numpy()
            # Scale a copy: ``.numpy()`` of a CPU tensor shares memory, so
            # scaling in place would corrupt the caller's boxes.
            bbox = np.array(bbox, dtype=np.float64)
            bbox[:, 0::2] *= image.shape[1]
            bbox[:, 1::2] *= image.shape[0]
            for left, top, right, bottom in bbox.astype("int").tolist():
                # OpenCV expects points as tuples of plain Python ints
                image = cv2.rectangle(image, (left, top), (right, bottom), (0, 0, 0))

        cv2.imshow("i", image)

        if segment is not None:
            if isinstance(segment, pt.Tensor):
                segment = segment.detach().cpu().numpy()
            segment = draw_segmentation_np(image, segment, 0.6)
            cv2.imshow("s", segment)

        cv2.waitKey(wait)
        return image, segment
