from pathlib import Path
import time

import cv2
import lmdb
import numpy as np
import torch as pt
import torch.utils.data as ptud

from .dataset import compress, decompress
from .utils import normaliz_for_visualiz, draw_segmentation_np


class MOVi(ptud.Dataset):
    """Wrap LMDB file reading in a PyTorch Dataset,
    which is much faster than the original TFRecord version.
    As the data is compressed ~10x, the dataset can be put in RAM disk /dev/shm/ for extra speed.

    MOVi-A | Number of objects in a scene distribution:
    - train: {3: 1170, 4: 1239, 5: 1185, 6: 1242, 7: 1253, 8: 1249, 9: 1177, 10: 1188} by ``bbox.size(1)``
        {3: 1170, 4: 1239, 5: 1185, 6: 1242, 7: 1255, 8: 1247, 9: 1177, 10: 1188} by ``segment.max()``
    - val: {3: 23, 4: 40, 5: 26, 6: 35, 7: 31, 8: 33, 9: 27, 10: 35} by ``bbox.size(1)``
        the same by ``segment.max()``

    MOVi-B
    - train: {1: 1, 2: 23, 3: 1238, 4: 1283, 5: 1222, 6: 1250, 7: 1209, 8: 1201, 9: 1235, 10: 1088}
    - val: {3: 32, 4: 34, 5: 33, 6: 36, 7: 33, 8: 24, 9: 31, 10: 27}

    MOVi-C
    - train: {2: 19, 3: 1233, 4: 1215, 5: 1221, 6: 1233, 7: 1270, 8: 1274, 9: 1199, 10: 1073}
    - val: {3: 37, 4: 36, 5: 34, 6: 31, 7: 30, 8: 35, 9: 18, 10: 29}

    MOVi-D
    - train: {1: 35, 2: 5, 3: 19, 4: 48, 5: 120, 6: 237, 7: 501, 8: 842, 9: 1128, 10: 1283, 11: 1309, 12: 1234, 13: 1017, 14: 791, 15: 549, 16: 328, 17: 176, 18: 83, 19: 34, 20: 11}
    - val: {1: 2, 3: 3, 4: 1, 5: 2, 6: 5, 7: 10, 8: 22, 9: 23, 10: 30, 11: 37, 12: 37, 13: 26, 14: 19, 15: 14, 16: 11, 17: 6, 19: 2}

    MOVi-E
    - train: {1: 2, 2: 1, 4: 5, 5: 11, 6: 30, 7: 97, 8: 223, 9: 465, 10: 742, 11: 1012, 12: 1154, 13: 1219, 14: 1033, 15: 1042, 16: 860, 17: 733, 18: 520, 19: 352, 20: 160, 21: 64, 22: 22, 23: 2}
    - val: {5: 1, 7: 4, 8: 9, 9: 11, 10: 18, 11: 18, 12: 33, 13: 43, 14: 26, 15: 24, 16: 20, 17: 12, 18: 17, 19: 10, 20: 4}

    MOVi-F
    - train: {1: 2, 3: 3, 4: 4, 5: 18, 6: 74, 7: 226, 8: 433, 9: 618, 10: 685, 11: 681, 12: 652, 13: 666, 14: 535, 15: 477, 16: 335, 17: 206, 18: 92, 19: 28, 20: 2}
    - val: {6: 5, 7: 6, 8: 11, 9: 9, 10: 17, 11: 17, 12: 16, 13: 20, 14: 15, 15: 13, 16: 10, 17: 3, 18: 4, 19: 1}

    Frame size in a scene:
    - timestep=24, height=128, width=128, channel=3.

    Video Statistics
    - mean: [111.58136, 110.821045, 109.13506]; 110.51248833333334
    - std: [18.089277, 17.647142, 19.824444]; 18.544110862661395
    Flow Statistics
    - mean: [241.2661, 247.95844, 247.82237]; 245.68230333333335
    - std: [52.802628, 36.19996 , 36.595848]; 42.57468229886195

    NOTE(review): the flow statistics were presumably computed on LMDB files
    converted before the flow-scale fix in ``_flow_to_hsv`` (which used to
    normalize by ``hypot(width, 2)`` instead of ``hypot(height, width)``);
    recompute them if you regenerate the data.
    """

    def __init__(
        self, data_file, transform=lambda **_: _, max_spare=4, base_dir: Path = None
    ):
        """Open an LMDB file produced by ``convert_dataset`` for reading.

        - data_file: path of the ``*.lmdb`` file; joined onto ``base_dir`` if given.
        - transform: called as ``transform(**sample)`` in ``__getitem__``;
          the default returns the sample dict unchanged.
        - max_spare: forwarded to ``lmdb.open`` as ``max_spare_txns``.
        - base_dir: optional directory prefix for ``data_file`` (Path or str).
        """
        if base_dir:
            data_file = Path(base_dir) / data_file
        # read-only and lock-free: assumes nothing writes the file concurrently
        self.env = lmdb.open(
            str(data_file),
            subdir=False,
            readonly=True,
            readahead=False,
            meminit=False,
            max_spare_txns=max_spare,
            lock=False,
        )
        with self.env.begin(write=False) as txn:
            # ``convert_dataset`` stores the list of sample keys under this reserved key
            self.keys = decompress(txn.get(b"__keys__"))
        self.transform = transform

    def __getitem__(self, index):
        """Load and decode the ``index``-th scene, then apply ``self.transform``.

        Fields passed to the transform:
        video: in shape (t,c=3,h,w), uint8
        bbox: in shape (t,n,c=4), float32, normalized, only foreground objects
        smask: in shape (t,n), bool, all True
        flow: in shape (t,c=3,h,w), uint8, optical flow rendered as RGB
        depth: in shape (t,h,w), float32
        segment: in shape (t,h,w), uint8
        Returns whatever ``self.transform(**sample)`` returns.
        """
        with self.env.begin(write=False) as txn:
            sample = decompress(txn.get(self.keys[index]))
        t, n, _ = sample["bbox"].shape
        sample = dict(
            video=pt.from_numpy(sample["video"]).permute(0, 3, 1, 2),  # (t,c,h,w) uint8
            bbox=pt.from_numpy(sample["bbox"]),  # (t,n,c), float32
            smask=pt.ones([t, n], dtype=pt.bool),
            flow=pt.from_numpy(sample["flow"]).permute(0, 3, 1, 2),  # (t,c,h,w), uint8
            depth=pt.from_numpy(sample["depth"]),  # (t,h,w), float32
            segment=pt.from_numpy(sample["segment"]),  # (t,h,w), uint8
        )
        return self.transform(**sample)

    def __len__(self):
        """Number of scenes, i.e. the number of stored sample keys."""
        return len(self.keys)

    @staticmethod
    def convert_dataset(
        src_dir="/media/GeneralZ/Storage/Static/datasets/tfds",
        tfds_name="movi_f/128x128:1.0.0",
        dst_dir=Path("movi_f"),
    ):
        """
        Convert the original TFRecord files into one LMDB file per split, saving 10x storage space.

        Note: This requires the following TensorFlow-series libs, which could mess up your environment,
        so to run this part you had better install them solely in a separate environment.
        ```
        clu==0.0.10
        tensorflow_cpu
        tensorflow_datasets
        ```

        Download MOVi series datasets. Remember to install gsutil first https://cloud.google.com/storage/docs/gsutil_install
        ```bash
        cd local/path/to/movi_a/
        gsutil -m cp -r gs://kubric-public/tfds/movi_a/128x128/1.0.0 .
        # download movi_b, c, d, e, f in the similar way if needed
        ```

        Finally create a Python script with the following content at the project root, and execute it:
        ```python
        from object_centric_bench.datum import MOVi
        MOVi.convert_dataset()  # remember to change default paths to yours
        ```

        Writes ``dst_dir/train.lmdb`` and ``dst_dir/validation.lmdb``; each holds samples
        under keys ``b"000000"``, ``b"000001"``, ... plus the key list under ``b"__keys__"``.
        """
        dst_dir = Path(dst_dir)  # also accept plain str paths
        dst_dir.mkdir(parents=True, exist_ok=True)

        # Imported here so the Dataset part of this class works without TensorFlow.
        from clu import deterministic_data
        import tensorflow as tf
        import tensorflow.python.framework.ops as tfpfo
        import tensorflow_datasets as tfds

        # Let TF grow GPU memory on demand instead of reserving it all up front.
        for gpu in tf.config.list_physical_devices("GPU"):
            tf.config.experimental.set_memory_growth(gpu, True)

        for split in ["train", "validation"]:
            print(split)

            dataset_builder = tfds.builder(tfds_name, data_dir=src_dir)
            dataset_split = deterministic_data.get_read_instruction_for_host(
                split,
                dataset_builder.info.splits[split].num_examples,
            )
            dataset = deterministic_data.create_dataset(
                dataset_builder,
                split=dataset_split,
                batch_dims=(),
                num_epochs=1,
                shuffle=False,
            )

            lmdb_file = dst_dir / f"{split}.lmdb"
            lmdb_env = lmdb.open(
                str(lmdb_file),
                map_size=1024**4,  # maximum size the database may grow to (1 TiB)
                subdir=False,
                readonly=False,
                meminit=False,
            )
            try:
                keys = []
                txn = lmdb_env.begin(write=True)
                t0 = time.time()

                for i, sample in enumerate(dataset):
                    converted = __class__._convert_sample(sample, tfpfo, tf)
                    sample_key = f"{i:06d}".encode("ascii")
                    keys.append(sample_key)
                    txn.put(sample_key, compress(converted))

                    if len(keys) % 64 == 0:  # write_freq: commit in batches
                        print(f"{len(keys):06d}")
                        txn.commit()
                        txn = lmdb_env.begin(write=True)

                txn.commit()
                # average seconds per sample; max() guards against an empty split
                print((time.time() - t0) / max(len(keys), 1))  # 0.0298842241987586

                with lmdb_env.begin(write=True) as txn:  # commits on normal exit
                    txn.put(b"__keys__", compress(keys))
            finally:
                lmdb_env.close()

    @staticmethod
    def _convert_sample(sample, tfpfo, tf) -> dict:
        """Run the full TFDS -> numpy preprocessing chain on one raw TFDS sample."""
        pack = __class__.convert_nested_mapping(sample, tfpfo, tf)
        pack = __class__.video_from_tfds(pack)
        pack = __class__.bbox_sparse_to_dense(pack)
        pack = __class__.unpack_uint16_to_float32(pack)
        return __class__.flow_to_rgb(pack)

    @staticmethod
    def convert_nested_mapping(mapping: dict, tfpfo, tf):
        """Recursively convert a (nested) dict of TF tensors into plain Python/numpy values.

        ``tfpfo`` (``tensorflow.python.framework.ops``) and ``tf`` are passed in so that
        TensorFlow only needs to be imported inside ``convert_dataset``.
        Eager tensors are converted with ``.numpy()``; ragged tensors with ``.to_list()``;
        nested dicts are converted recursively.

        Raises NotImplementedError for any other value type.
        """
        mapping2 = {}
        for key, value in mapping.items():
            if isinstance(value, dict):
                value2 = __class__.convert_nested_mapping(value, tfpfo, tf)
            elif isinstance(value, tfpfo.EagerTensor):
                value2 = value.numpy()
            elif isinstance(value, tf.RaggedTensor):
                value2 = value.to_list()
            else:
                # (raising a bare string is itself a TypeError in Python 3)
                raise NotImplementedError(
                    f"unsupported value type {type(value).__name__!r} for key {key!r}"
                )
            mapping2[key] = value2
        return mapping2

    @staticmethod
    def video_from_tfds(pack: dict) -> dict:
        """Adopted from SAVi official implementation VideoFromTfds class.

        Regroup a converted TFDS sample into the fields used downstream:
        ``video`` (uint8), ``mask`` (all ones), ``bbox`` (sparse ``track``/``bbox`` lists),
        ``flow`` / ``depth`` (uint16-packed ``data`` with their ``min``/``max`` ranges)
        and ``segment`` (first channel of ``segmentations``).
        """
        video = pack["video"].astype("uint8")  # (t,h,w,c)
        mask = np.ones_like(video)[..., 0]  # (t,h,w)

        track = pack["instances"]["bbox_frames"]
        bbox = pack["instances"]["bboxes"]

        flow_range = pack["metadata"]["backward_flow_range"]
        flow = pack["backward_flow"]  # 0~65535 (t,h,w,c=2)

        depth_range = pack["metadata"]["depth_range"]
        depth = pack["depth"][:, :, :, 0]  # 0~65535 (t,h,w)

        segment = pack["segmentations"][:, :, :, 0]  # (t,h,w)

        return dict(
            video=video,
            mask=mask,
            bbox=dict(track=track, bbox=bbox),
            flow=dict(data=flow, min=flow_range[0], max=flow_range[1]),
            depth=dict(data=depth, min=depth_range[0], max=depth_range[1]),
            segment=segment,
        )

    @staticmethod
    def bbox_sparse_to_dense(pack: dict, notrack=0) -> dict:  # TODO notrack=-1
        """Adopted from SAVi official implementation SparseToDenseAnnotation class.

        ``pack["bbox"]`` holds, per object, the list of frame indices where it is tracked
        (``track``) and its boxes on those frames (``bbox``). It is replaced by a dense
        float32 array of shape (t,n,c=4) where untracked frames are filled with ``notrack``.
        ``t`` is taken from ``pack["segment"].shape[0]``.
        """

        def densify_bbox(tracks: list, bboxs_s: list, timestep: int):
            assert len(tracks) == len(bboxs_s)

            null_box = np.array([notrack] * 4, dtype="float32")
            bboxs_d = np.tile(null_box, [timestep, len(tracks), 1])  # (t,n,c=4)

            for i, (track, bbox_s) in enumerate(zip(tracks, bboxs_s)):
                idx = np.array(track, dtype="int64")
                value = np.array(bbox_s, dtype="float32")
                bboxs_d[idx, i] = value

            return bboxs_d  # (t,n,c=4)

        track = pack["bbox"]["track"]
        bbox0 = pack["bbox"]["bbox"]

        segment = pack["segment"]
        # segment ids must not exceed the number of annotated objects
        assert segment.max() <= len(track)

        pack["bbox"] = densify_bbox(track, bbox0, segment.shape[0])
        return pack

    @staticmethod
    def unpack_uint16_to_float32(pack: dict) -> dict:
        """Adopted from SAVi official implementation VideoFromTfds class.

        Map the uint16-packed ``flow`` and ``depth`` entries linearly from [0, 65535]
        back to float32 values in their stored [min, max] ranges.
        """

        def unpack_uint16_to_float32(data, min, max):  # names match the dict keys
            assert data.dtype == np.uint16
            return data.astype("float32") / 65535.0 * (max - min) + min

        for key in ["flow", "depth"]:
            pack[key] = unpack_uint16_to_float32(**pack[key])
        return pack

    @staticmethod
    def _flow_to_hsv(flow, flow_scale=50.0, hsv_scale=(180.0, 255.0, 255.0)):
        """Encode a (t,h,w,c=2) flow field as a uint8 HSV video in OpenCV's value ranges.

        Hue is the motion angle, saturation the motion magnitude scaled by
        ``flow_scale / hypot(h, w)`` and clipped to [0, 1], value is constant 1;
        the three channels are then multiplied by ``hsv_scale`` and truncated to uint8.
        """
        assert flow.ndim == 4
        hypot = lambda a, b: (a**2.0 + b**2.0) ** 0.5  # sqrt(a^2 + b^2)

        # Normalize by the frame diagonal. ``flow`` is (t,h,w,c), so height/width are
        # ``shape[1:3]``; the earlier ``shape[2:4]`` wrongly picked (w, c).
        height, width = flow.shape[1:3]
        scale = flow_scale / hypot(height, width)
        hsv_scale = np.array(hsv_scale, dtype="float32")[None, None, None]

        x, y = flow[..., 0], flow[..., 1]

        h = np.arctan2(y, x)  # motion angle
        h = (h / np.pi + 1.0) / 2.0
        s = hypot(y, x)  # motion magnitude
        s = np.clip(s * scale, 0.0, 1.0)
        v = np.ones_like(h)

        hsv = np.stack([h, s, v], axis=3)
        return (hsv * hsv_scale).astype("uint8")

    @staticmethod
    def flow_to_rgb(pack: dict) -> dict:
        """Adopted from SAVi official implementation FlowToRgb class.

        Replace ``pack["flow"]`` (float (t,h,w,c=2)) with an RGB uint8 video (t,h,w,c=3).
        """
        # ``torchvision.utils.flow_to_image`` got strange result, hence OpenCV
        hsv = __class__._flow_to_hsv(pack["flow"])
        pack["flow"] = np.array([cv2.cvtColor(_, cv2.COLOR_HSV2RGB) for _ in hsv])
        return pack

    @staticmethod
    def visualiz(video, segment=None, wait=0):
        """Show a video (and optionally a segmentation overlay) frame by frame with OpenCV.

        - video: (t,c,h,w) tensor or (t,h,w,c) array; mapped to uint8 by
          ``clip(x * 127.5 + 127.5)``, so presumably expected in [-1, 1].
        - segment: optional per-frame segmentation (tensor or array), drawn with
          ``draw_segmentation_np`` onto each frame.
        - wait: passed to ``cv2.waitKey`` after every frame (0 waits for a key press).
        Returns the lists of displayed frames and segmentation overlays.
        """
        if isinstance(video, pt.Tensor):
            video = video.permute(0, 2, 3, 1).cpu().contiguous().numpy()
        video = np.clip(video * 127.5 + 127.5, 0, 255).astype("uint8")

        if isinstance(segment, pt.Tensor):
            segment = segment.cpu().numpy()
        # ``segment`` defaults to None, so guard before calling len() on it
        show_segment = segment is not None and len(segment) > 0

        imgs = []
        segs = []
        for t, img in enumerate(video):
            cv2.imshow("v", img)
            imgs.append(img)

            if show_segment:
                seg = draw_segmentation_np(img, segment[t])
                cv2.imshow("s", seg)
                segs.append(seg)

            cv2.waitKey(wait)

        return imgs, segs
